From 52806ba952207f74148c178ec260d8e8629f0126 Mon Sep 17 00:00:00 2001 From: Walid Krichene Date: Thu, 16 Mar 2023 12:35:19 -0700 Subject: [PATCH] In dp_optimizer_keras_sparse, update `iterations` to reflect the number of logical batches, rather than physical batches. In the current behavior, when using gradient accumulation, the `iterations` variable is incremented at every physical batch, while variables are only updated at every logical batch (where logical batch = accumulation_steps many physical batches). This causes certain optimizers that explicitly depend on `iterations` (such as Adam) to behave very differently under gradient accumulation. With this change, `iterations` is only incremented after each logical batch. PiperOrigin-RevId: 517197044 --- tensorflow_privacy/privacy/optimizers/BUILD | 9 ++ .../optimizers/dp_optimizer_keras_sparse.py | 72 +++++++-- ...optimizer_keras_sparse_distributed_test.py | 153 ++++++++++++++++++ .../dp_optimizer_keras_sparse_test.py | 33 ++-- 4 files changed, 247 insertions(+), 20 deletions(-) create mode 100644 tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_sparse_distributed_test.py diff --git a/tensorflow_privacy/privacy/optimizers/BUILD b/tensorflow_privacy/privacy/optimizers/BUILD index 166b3e5..3eba490 100644 --- a/tensorflow_privacy/privacy/optimizers/BUILD +++ b/tensorflow_privacy/privacy/optimizers/BUILD @@ -109,6 +109,15 @@ py_test( deps = [":dp_optimizer_keras_sparse"], ) +py_test( + name = "dp_optimizer_keras_sparse_distributed_test", + timeout = "long", + srcs = ["dp_optimizer_keras_sparse_distributed_test.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [":dp_optimizer_keras_sparse"], +) + py_test( name = "dp_optimizer_vectorized_test", timeout = "long", diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_sparse.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_sparse.py index d8afe67..c3aa34c 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_sparse.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_sparse.py @@ -185,6 +185,7 @@ def make_sparse_keras_optimizer_class(cls): self._num_microbatches = num_microbatches self._was_dp_gradients_called = False self._noise_stddev = None + self._acc_iterations = None if self._num_microbatches is not None: # The loss/gradients is the mean over the microbatches so we # divide the noise by num_microbatches too to obtain the correct @@ -202,15 +203,29 @@ def make_sparse_keras_optimizer_class(cls): def _create_slots(self, var_list): super()._create_slots(var_list) # pytype: disable=attribute-error if self.gradient_accumulation_steps > 1: + # Slots for accumulating gradients. for var in var_list: self.add_slot(var, 'grad_acc') + if self._acc_iterations is None: + # Variable for the iterations, used for bookkeeping when to accumulate + # and when to update. + self._acc_iterations = self.add_weight( + 'acc_iterations', + shape=[], + trainable=False, + dtype=tf.int64, + aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, + ) def _prepare_local(self, var_device, var_dtype, apply_state): super()._prepare_local(var_device, var_dtype, apply_state) # pytype: disable=attribute-error if self.gradient_accumulation_steps > 1: apply_update = tf.math.equal( - tf.math.floormod(self.iterations + 1, - self.gradient_accumulation_steps), 0) + tf.math.floormod( + self._acc_iterations + 1, self.gradient_accumulation_steps + ), + 0, + ) grad_scaler = tf.cast(1. / self.gradient_accumulation_steps, var_dtype) apply_state[(var_device, var_dtype)].update({ 'apply_update': apply_update, @@ -218,7 +233,7 @@ def make_sparse_keras_optimizer_class(cls): }) def _resource_apply(self, accum_op, grad, var, apply_state=None): - """Help method for _resource_apply_dense and _resource_apply_sparse.""" + """Helper method for _resource_apply_dense and _resource_apply_sparse.""" if self.gradient_accumulation_steps > 1: var_device, var_dtype = var.device, var.dtype.base_dtype coefficients = ((apply_state or {}).get((var_device, var_dtype)) or @@ -235,9 +250,8 @@ def make_sparse_keras_optimizer_class(cls): tf.zeros_like(grad_acc), use_locking=self._use_locking, read_value=False) - accum_op(grad_acc, grad, use_locking=self._use_locking) - return tf.cond( - coefficients['apply_update'], _update_grad, lambda: tf.no_op()) # pylint: disable=unnecessary-lambda + with tf.control_dependencies([accum_op(grad_acc, grad)]): + return tf.cond(coefficients['apply_update'], _update_grad, tf.no_op) else: grad = tf.convert_to_tensor(grad) grad = grad + self._generate_noise(grad) @@ -246,9 +260,11 @@ def make_sparse_keras_optimizer_class(cls): def _resource_apply_dense(self, grad, var, apply_state=None): """Handles dense gradients.""" - def _accum_op(grad_acc, grad, use_locking): + def _accum_op(grad_acc, grad): return grad_acc.assign_add( - grad, use_locking=use_locking, read_value=False) + grad, use_locking=self._use_locking, read_value=False + ) + return self._resource_apply(_accum_op, grad, var, apply_state) # This method is implemented the same as that in optimizer_v2.py. We @@ -271,13 +287,49 @@ def make_sparse_keras_optimizer_class(cls): def _resource_apply_sparse(self, grad, var, indices, apply_state=None): """Handles deduped sparse gradients.""" - def _accum_op(grad_acc, sparse_delta, use_locking): + def _accum_op(grad_acc, sparse_delta): return grad_acc.scatter_add( - sparse_delta=sparse_delta, use_locking=use_locking) + sparse_delta=sparse_delta, use_locking=self._use_locking + ) + sparse_delta = tf.IndexedSlices( values=grad, indices=indices, dense_shape=var.shape) return self._resource_apply(_accum_op, sparse_delta, var, apply_state) + def _distributed_apply( + self, distribution, grads_and_vars, apply_state, name + ): + apply_op = super()._distributed_apply( + distribution, grads_and_vars, apply_state, name + ) + if self.gradient_accumulation_steps > 1: + # The original _distributed_apply increments self.iterations after each + # call. But we want to increment it only after each logical batch is + # processed, so optimizers that explicitly use self.iterations in their + # updates (such as Adam) can use the correct value. + def increment_acc_iterations(): + # Always use locking when updating the steps, so we don't under-count + # the steps (which could invalidate privacy accounting). + return self._acc_iterations.assign_add( + 1, use_locking=True, read_value=False + ) + + def assign_iterations(): + return self.iterations.assign( + tf.math.floordiv( + self._acc_iterations, self.gradient_accumulation_steps + ), + use_locking=True, + read_value=False, + ) + + with tf.control_dependencies([apply_op]): + with tf.control_dependencies([increment_acc_iterations()]): + return assign_iterations() + else: + # No accumulation. + return apply_op + def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None): """DP-SGD version of base class method.""" self._was_dp_gradients_called = True diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_sparse_distributed_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_sparse_distributed_test.py new file mode 100644 index 0000000..b178494 --- /dev/null +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_sparse_distributed_test.py @@ -0,0 +1,153 @@ +# Copyright 2023, The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests DPSparseKerasSGDOptimizer in distributed training.""" +import contextlib +import multiprocessing +import os +import sys +from absl.testing import parameterized +import numpy as np +import tensorflow as tf + +from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras_sparse + +ds_combinations = tf.__internal__.distribute.combinations + + +STRATEGIES = [ + ds_combinations.one_device_strategy, + ds_combinations.parameter_server_strategy_1worker_2ps_cpu, +] + + +class DistributedTrainingTest(parameterized.TestCase, tf.test.TestCase): + + @ds_combinations.generate( + tf.__internal__.test.combinations.combine( + strategy=STRATEGIES, mode="eager" + ) + ) + def test_training_works(self, strategy): + if isinstance(strategy, tf.distribute.OneDeviceStrategy): + strategy_scope = contextlib.nullcontext() + else: + strategy_scope = strategy.scope() + + def make_model(): + inputs = tf.keras.Input((1000,)) + dense = tf.keras.layers.Dense( + units=1, use_bias=False, kernel_initializer=tf.initializers.ones() + ) + outputs = dense(inputs) + return tf.keras.models.Model(inputs=inputs, outputs=outputs) + + x = tf.ones(shape=[5000, 1000]) + y = tf.zeros(shape=[5000]) + with strategy_scope: + model = make_model() + clip = 100.0 + noise_mult = 0.01 + acc_steps = 5 + batch_size = 10 + opt = dp_optimizer_keras_sparse.DPSparseKerasSGDOptimizer( + l2_norm_clip=clip, + noise_multiplier=noise_mult, + gradient_accumulation_steps=acc_steps, + learning_rate=0.001, + ) + model.compile( + loss=tf.keras.losses.MeanAbsoluteError( + reduction=tf.keras.losses.Reduction.NONE + ), + optimizer=opt, + ) + history = model.fit( + x=x, + y=y, + epochs=2, + steps_per_epoch=500, + batch_size=batch_size, + ) + self.assertIn("loss", history.history) + # total steps: 1000 (2 epochs, 500 steps/epoch) + # accumulation steps: 5 + # expected_iterations = total steps / accumulation steps + expected_iterations = 1000 / acc_steps # = 200 + # The loss is |w.x - y| (where w is the dense layer weights). + # The gradient is sign(w.x - y)x. With the choice of x, y, the gradient + # becomes x. + # So each gradient update is w <- w - learning_rate*1 + noise + expected_params = 1 - 0.001 * expected_iterations + expected_noise = ( + 0.001 + * clip + * noise_mult + * np.sqrt(expected_iterations) + / (acc_steps * batch_size) + ) + self.assertEqual(opt.iterations.numpy(), expected_iterations) + self.assertAllClose( + np.mean(model.trainable_variables[0].numpy()), + expected_params, # 0.8 + # stddev = expected_noise/√1000 (since we're averaging 1000 samples) + # we set atol to 4 stddev + atol=4 * expected_noise / np.sqrt(1000), # 0.0358 + ) + self.assertAllClose( + np.std(model.trainable_variables[0].numpy()), + expected_noise, # 0.2828 + atol=4 * expected_noise / np.sqrt(1000), # 0.0358 + ) + + +def _set_spawn_exe_path(): + """Set the path to the executable for spawned processes. + + This utility searches for the binary the parent process is using, and sets + the executable of multiprocessing's context accordingly. + It is branched from tensorflow/python/distribute/multi_process_lib.py, the + only change being that it additionally looks under "org_tensorflow_privacy". + """ + if sys.argv[0].endswith(".py"): + + def guess_path(package_root): + # If all we have is a python module path, we'll need to make a guess for + # the actual executable path. + if "bazel-out" in sys.argv[0] and package_root in sys.argv[0]: + # Guess the binary path under bazel. For target + # //tensorflow/python/distribute:input_lib_test_multiworker_gpu, the + # argv[0] is in the form of + # /.../tensorflow/python/distribute/input_lib_test.py + # and the binary is + # /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu + package_root_base = sys.argv[0][: sys.argv[0].rfind(package_root)] + binary = os.environ["TEST_TARGET"][2:].replace(":", "/", 1) + possible_path = os.path.join(package_root_base, package_root, binary) + if os.access(possible_path, os.X_OK): + return possible_path + return None + + path = ( + guess_path("org_tensorflow") + or guess_path("org_keras") + or guess_path("org_tensorflow_privacy") + ) + if path is not None: + sys.argv[0] = path + multiprocessing.get_context().set_executable(sys.argv[0]) + + +if __name__ == "__main__": + _set_spawn_exe_path() + tf.__internal__.distribute.multi_process_runner.test_main() diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_sparse_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_sparse_test.py index ce67f87..182d647 100644 --- a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_sparse_test.py +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_sparse_test.py @@ -338,21 +338,25 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): # After first call to optimizer values didn't change self.assertAllCloseAccordingToType([[1.0, 2.0]], var0) self.assertAllCloseAccordingToType([3.0], var1) + self.assertEqual(opt.iterations, 0) opt.minimize(loss2, [var0, var1]) # After second call to optimizer updates were applied self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0) self.assertAllCloseAccordingToType([2.0], var1) + self.assertEqual(opt.iterations, 1) opt.minimize(loss2, [var0, var1]) # After third call to optimizer values didn't change self.assertAllCloseAccordingToType([[-1.0, 1.0]], var0) self.assertAllCloseAccordingToType([2.0], var1) + self.assertEqual(opt.iterations, 1) opt.minimize(loss2, [var0, var1]) # After fourth call to optimizer updates were applied again self.assertAllCloseAccordingToType([[-4.0, -0.5]], var0) self.assertAllCloseAccordingToType([1.0], var1) + self.assertEqual(opt.iterations, 2) @parameterized.named_parameters( ('DPSparseKerasSGDOptimizer 1', @@ -388,6 +392,7 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): self.assertNotAllClose([[1.0, 2.0]], var0) self.assertNotAllClose([3.0], var1) + self.assertEqual(opt.iterations, 1) def testKerasModelBaselineSaving(self): """Tests that DP optimizers work with tf.keras.Model.""" @@ -455,10 +460,15 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): model.fit(train_data, train_labels, batch_size=8, epochs=1, shuffle=False) - @parameterized.named_parameters(('1', 1), ('None', None)) - def testKerasModelBaselineNoNoise(self, num_microbatches): + @parameterized.named_parameters( + ('no_microbatch_no_accumulation', False, False), + ('no_microbatch_accumulation', False, True), + ('microbatch_no_accumulation', True, False), + ('microbatch_accumulation', True, True), + ) + def testKerasModelBaselineNoNoise(self, microbatch, accumulate): """Tests that DP optimizers work with tf.keras.Model.""" - + acc_steps = 2 if accumulate else 1 model = tf.keras.models.Sequential(layers=[ tf.keras.layers.Dense( 1, @@ -471,22 +481,25 @@ class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): optimizer = dp_optimizer.DPSparseKerasSGDOptimizer( l2_norm_clip=100.0, noise_multiplier=0.0, - num_microbatches=num_microbatches, - learning_rate=0.05) + num_microbatches=None if microbatch else 1, + learning_rate=0.05, + gradient_accumulation_steps=acc_steps, + ) loss = tf.keras.losses.MeanSquaredError(reduction='none') model.compile(optimizer, loss) true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32) true_bias = np.array([6.0]).astype(np.float32) - train_data = np.random.normal(scale=3.0, size=(1000, 4)).astype(np.float32) - train_labels = np.matmul(train_data, - true_weights) + true_bias + np.random.normal( - scale=0.0, size=(1000, 1)).astype(np.float32) + train_data = np.random.normal(scale=3.0, size=(2000, 4)).astype(np.float32) + train_labels = np.matmul(train_data, true_weights) + true_bias - model.fit(train_data, train_labels, batch_size=8, epochs=1, shuffle=False) + model.fit(train_data, train_labels, batch_size=10, epochs=1, shuffle=False) self.assertAllClose(model.get_weights()[0], true_weights, atol=0.05) self.assertAllClose(model.get_weights()[1], true_bias, atol=0.05) + # Check that the optimizer's iterations equal the number of logical batches. + total_batches = 200 + self.assertEqual(optimizer.iterations.numpy(), total_batches / acc_steps) if __name__ == '__main__':