diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py
index 5345c70..2fe2a7f 100644
--- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py
+++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py
@@ -49,7 +49,7 @@ def make_keras_optimizer_class(cls):
 
   ```python
   # Create optimizer.
-  opt = {dp_keras_class}(l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=1,
+  opt = {dp_keras_class}(l2_norm_clip=1.0, noise_multiplier=0.5, num_microbatches=1,
   )
   ```
 
@@ -81,6 +81,39 @@ def make_keras_optimizer_class(cls):
     model.fit(...)
   ```
 
+  In DP-SGD training, a larger batch size typically helps to achieve a better
+  privacy/utility tradeoff. However, there is typically a maximum batch size
+  imposed by hardware.
+  This optimizer can emulate large batch sizes on hardware with limited
+  memory by accumulating gradients for several steps before actually
+  applying them to update the model weights.
+  The constructor argument `gradient_accumulation_steps` controls the number
+  of steps for which gradients are accumulated before the model weights are
+  updated.
+
+  Below is an example that demonstrates how to use this feature:
+
+  ```python
+  # Create an optimizer which accumulates gradients for 4 steps
+  # and then performs an update of the model weights.
+  opt = {dp_keras_class}(l2_norm_clip=1.0,
+                         noise_multiplier=0.5,
+                         num_microbatches=1,
+                         gradient_accumulation_steps=4,
+  )
+
+  # Use the optimizer in the regular way. The first three calls to
+  # opt.minimize won't update the model weights and will only accumulate
+  # gradients. The model weights will be updated on the fourth call to
+  # opt.minimize.
+  opt.minimize(loss, var_list=[var])
+  ```
+
+  Note that when using this feature the effective batch size is
+  `gradient_accumulation_steps * one_step_batch_size`, where
+  `one_step_batch_size` is the size of the batch passed to a single step
+  of the optimizer. Thus the user may have to adjust the learning rate,
+  weight decay and possibly other training hyperparameters accordingly.
   """.format(
       base_class='tf.keras.optimizers.' + cls.__name__,
       short_base_class=cls.__name__,
@@ -100,6 +133,7 @@ def make_keras_optimizer_class(cls):
         l2_norm_clip,
         noise_multiplier,
         num_microbatches=None,
+        gradient_accumulation_steps=1,
         *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
         **kwargs):
       """Initialize the DPOptimizerClass.
 
@@ -108,11 +142,21 @@ def make_keras_optimizer_class(cls):
         l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients).
         noise_multiplier: Ratio of the standard deviation to the clipping norm.
         num_microbatches: Number of microbatches into which each minibatch is
-          split.
+          split. The default is `None`, which means that the number of
+          microbatches equals the batch size (i.e. each microbatch contains
+          exactly one example). If `gradient_accumulation_steps` is greater
+          than 1 and `num_microbatches` is not `None`, then the effective
+          number of microbatches is
+          `num_microbatches * gradient_accumulation_steps`.
+        gradient_accumulation_steps: If greater than 1, the optimizer
+          accumulates gradients for this number of optimizer steps before
+          applying them to update the model weights. If set to 1, updates
+          are applied on every optimizer step.
         *args: These will be passed on to the base class `__init__` method.
         **kwargs: These will be passed on to the base class `__init__` method.
""" super(DPOptimizerClass, self).__init__(*args, **kwargs) + self.gradient_accumulation_steps = gradient_accumulation_steps self._l2_norm_clip = l2_norm_clip self._noise_multiplier = noise_multiplier self._num_microbatches = num_microbatches @@ -121,6 +165,69 @@ def make_keras_optimizer_class(cls): self._global_state = None self._was_dp_gradients_called = False + def _create_slots(self, var_list): + super(DPOptimizerClass, self)._create_slots(var_list) + if self.gradient_accumulation_steps > 1: + for var in var_list: + self.add_slot(var, 'grad_acc') + + def _prepare_local(self, var_device, var_dtype, apply_state): + super(DPOptimizerClass, self)._prepare_local( + var_device, var_dtype, apply_state) + if self.gradient_accumulation_steps > 1: + apply_update = tf.math.equal( + tf.math.floormod(self.iterations + 1, + self.gradient_accumulation_steps), + 0) + grad_scaler = tf.cast(1. / self.gradient_accumulation_steps, var_dtype) + apply_state[(var_device, var_dtype)].update( + { + 'apply_update': apply_update, + 'grad_scaler': grad_scaler + }) + + def _resource_apply_dense(self, grad, var, apply_state=None): + if self.gradient_accumulation_steps > 1: + var_device, var_dtype = var.device, var.dtype.base_dtype + coefficients = ((apply_state or {}).get((var_device, var_dtype)) + or self._fallback_apply_state(var_device, var_dtype)) + grad_acc = self.get_slot(var, 'grad_acc') + + def _update_grad(): + apply_grad_op = super(DPOptimizerClass, self)._resource_apply_dense( + grad_acc + grad * coefficients['grad_scaler'], var, apply_state) + with tf.control_dependencies([apply_grad_op]): + return grad_acc.assign(tf.zeros_like(grad_acc), + use_locking=self._use_locking, + read_value=False) + + def _accumulate(): + return grad_acc.assign_add(grad * coefficients['grad_scaler'], + use_locking=self._use_locking, + read_value=False) + + return tf.cond(coefficients['apply_update'], _update_grad, _accumulate) + else: + return super(DPOptimizerClass, self)._resource_apply_dense( + grad, var, apply_state) + + def _resource_apply_sparse_duplicate_indices(self, *args, **kwargs): + if self.gradient_accumulation_steps > 1: + raise NotImplementedError( + 'Sparse gradients are not supported with large batch emulation.') + else: + return super(DPOptimizerClass, + self)._resource_apply_sparse_duplicate_indices( + *args, **kwargs) + + def _resource_apply_sparse(self, *args, **kwargs): + if self.gradient_accumulation_steps > 1: + raise NotImplementedError( + 'Sparse gradients are not supported with large batch emulation.') + else: + return super(DPOptimizerClass, self)._resource_apply_sparse( + *args, **kwargs) + def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None): """DP-SGD version of base class method.""" diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py index 71b68c6..b4013bf 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py @@ -394,6 +394,87 @@ class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase): grads_and_vars = tf.Variable([0.0]) opt.apply_gradients(grads_and_vars) + def testLargeBatchEmulationNoNoise(self): + # Test for emulation of large batch training. + # It tests that updates are only done every gradient_accumulation_steps + # steps. + # In this test we set noise multiplier to zero and clipping norm to high + # value, such that optimizer essentially behave as non-DP optimizer. 
+    # This makes it easier to check how the values of the variables change.
+    #
+    # This test optimizes the loss var0 * x + var1.
+    # Gradients of this loss are computed as:
+    #   d(loss)/d(var0) = x
+    #   d(loss)/d(var1) = 1
+    var0 = tf.Variable([[1.0, 2.0]], dtype=tf.float32)
+    var1 = tf.Variable([3.0], dtype=tf.float32)
+    x1 = tf.constant([[2.0, 0.0], [0.0, 1.0]], dtype=tf.float32)
+    loss1 = lambda: tf.matmul(var0, x1, transpose_b=True) + var1
+    x2 = tf.constant([[4.0, 2.0], [2.0, 1.0]], dtype=tf.float32)
+    loss2 = lambda: tf.matmul(var0, x2, transpose_b=True) + var1
+
+    opt = dp_optimizer_keras.DPKerasSGDOptimizer(
+        l2_norm_clip=100.0,
+        noise_multiplier=0.0,
+        gradient_accumulation_steps=2,
+        learning_rate=1.0)
+
+    # Before any call to the optimizer.
+    self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
+    self.assertAllCloseAccordingToType([3.0], var1)
+
+    opt.minimize(loss1, [var0, var1])
+    # After the first call to the optimizer the values didn't change.
+    self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
+    self.assertAllCloseAccordingToType([3.0], var1)
+
+    opt.minimize(loss2, [var0, var1])
+    # After the second call the accumulated updates were applied.
+    self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
+    self.assertAllCloseAccordingToType([2.0], var1)
+
+    opt.minimize(loss2, [var0, var1])
+    # After the third call the values didn't change.
+    self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0)
+    self.assertAllCloseAccordingToType([2.0], var1)
+
+    opt.minimize(loss2, [var0, var1])
+    # After the fourth call the updates were applied again.
+    self.assertAllCloseAccordingToType([[-4.0, -0.5]], var0)
+    self.assertAllCloseAccordingToType([1.0], var1)
+
+  @parameterized.named_parameters(
+      ('DPKerasSGDOptimizer 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),
+      ('DPKerasSGDOptimizer 2', dp_optimizer_keras.DPKerasSGDOptimizer, 2),
+      ('DPKerasSGDOptimizer 4', dp_optimizer_keras.DPKerasSGDOptimizer, 4),
+      ('DPKerasAdamOptimizer 1',
+       dp_optimizer_keras.DPKerasAdamOptimizer, 1),
+      ('DPKerasAdagradOptimizer 2',
+       dp_optimizer_keras.DPKerasAdagradOptimizer, 2),
+  )
+  def testLargeBatchEmulation(self, cls, gradient_accumulation_steps):
+    # Tests various optimizers with large batch emulation. Since different
+    # optimizers apply different update rules, this test does not check
+    # specific variable values, only how often the variables are updated.
+    var0 = tf.Variable([[1.0, 2.0]], dtype=tf.float32)
+    var1 = tf.Variable([3.0], dtype=tf.float32)
+    x = tf.constant([[2.0, 0.0], [0.0, 1.0]], dtype=tf.float32)
+    loss = lambda: tf.matmul(var0, x, transpose_b=True) + var1
+
+    opt = cls(
+        l2_norm_clip=100.0,
+        noise_multiplier=0.0,
+        gradient_accumulation_steps=gradient_accumulation_steps,
+        learning_rate=1.0)
+
+    for _ in range(gradient_accumulation_steps):
+      self.assertAllCloseAccordingToType([[1.0, 2.0]], var0)
+      self.assertAllCloseAccordingToType([3.0], var1)
+      opt.minimize(loss, [var0, var1])
+
+    self.assertNotAllClose([[1.0, 2.0]], var0)
+    self.assertNotAllClose([3.0], var1)
+
 
 if __name__ == '__main__':
   tf.test.main()
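
For context beyond the patch itself, the sketch below shows how the new `gradient_accumulation_steps` argument composes with `Model.fit`. It is illustrative only: the model, data, and hyperparameter values are made up, and it assumes a TF2 environment with this patch applied. The `reduction=NONE` loss follows the library's existing requirement that DP optimizers receive per-example losses for microbatching.

```python
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras

# Hardware fits batches of 32; accumulating for 4 steps emulates an
# effective batch size of 4 * 32 = 128. (Values are illustrative.)
one_step_batch_size = 32
accumulation_steps = 4

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(1),
])

opt = dp_optimizer_keras.DPKerasSGDOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=0.5,
    num_microbatches=one_step_batch_size,
    gradient_accumulation_steps=accumulation_steps,
    learning_rate=0.1)

# Per-example losses (reduction=NONE) so the optimizer can microbatch.
loss = tf.keras.losses.MeanSquaredError(
    reduction=tf.keras.losses.Reduction.NONE)
model.compile(optimizer=opt, loss=loss)

x = tf.random.normal((1024, 10))
y = tf.random.normal((1024, 1))
# Each fit step sees one_step_batch_size examples; the weights change
# only on every accumulation_steps-th step.
model.fit(x, y, batch_size=one_step_batch_size, epochs=1)
```

As the new docstring notes, the learning rate and weight decay should be chosen for the effective batch size (here 128) rather than the per-step batch size.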
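The accumulation arithmetic itself is easy to check in isolation. The following hypothetical, plain-TensorFlow snippet (no clipping, noise, or microbatching) mirrors the `apply_update`/`grad_scaler` logic in `_prepare_local` and `_resource_apply_dense`: each step adds `grad / k` to the `grad_acc` slot, and on every k-th step the base update is applied with the accumulated value, i.e. the mean gradient over the window, after which the slot is reset.

```python
import tensorflow as tf

k = 4                        # plays the role of gradient_accumulation_steps
lr = 1.0
var = tf.Variable(0.0)       # a model weight
grad_acc = tf.Variable(0.0)  # the 'grad_acc' slot

for step, grad in enumerate([2.0, 4.0, 6.0, 8.0]):
  if (step + 1) % k == 0:    # same test as apply_update: (iterations + 1) mod k
    # Apply the mean gradient over the window, then reset the accumulator.
    var.assign_sub(lr * (grad_acc + grad / k))
    grad_acc.assign(0.0)
  else:
    # Accumulate the scaled gradient without touching the weight.
    grad_acc.assign_add(grad / k)

print(var.numpy())  # -5.0 == -lr * mean([2.0, 4.0, 6.0, 8.0])
```

The real implementation expresses this branch with `tf.cond` inside `_resource_apply_dense` so it works in graph mode, and orders the weight update before the slot reset via `tf.control_dependencies`.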