Add ability to specify number of microbatches in DPModel class.

PiperOrigin-RevId: 430358084
This commit is contained in:
Steve Chien 2022-02-22 20:31:42 -08:00 committed by A. Unique TensorFlower
parent bfdcb7f64f
commit a33afde0c1
2 changed files with 157 additions and 22 deletions

View file

@ -62,25 +62,37 @@ def make_dp_model_class(cls):
self,
l2_norm_clip,
noise_multiplier,
num_microbatches=None,
use_xla=True,
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
**kwargs):
"""Initializes the DPModelClass.
Args:
l2_norm_clip: Clipping norm (max L2 norm of per microbatch
gradients).
noise_multiplier: Ratio of the standard deviation to the clipping
norm.
use_xla: If `True`, compiles train_step to XLA.
*args: These will be passed on to the base class `__init__` method.
**kwargs: These will be passed on to the base class `__init__`
method.
Args:
l2_norm_clip: Clipping norm (max L2 norm of per microbatch
gradients).
noise_multiplier: Ratio of the standard deviation to the clipping
norm.
num_microbatches: Number of microbatches.
use_xla: If `True`, compiles train_step to XLA.
*args: These will be passed on to the base class `__init__` method.
**kwargs: These will be passed on to the base class `__init__`
method.
"""
super().__init__(*args, **kwargs)
self._l2_norm_clip = l2_norm_clip
self._noise_multiplier = noise_multiplier
# Given that `num_microbatches` was added as an argument after the fact,
# this check helps detect unintended calls to the earlier API.
# In particular, boolean values supplied to `use_xla` in the earlier API
# will raise an error.
if isinstance(num_microbatches, bool):
raise ValueError('Boolean value supplied for `num_microbatches`. '
'Did you intend it for `use_xla`?')
self._num_microbatches = num_microbatches
if use_xla:
self.train_step = tf.function(
self.train_step, experimental_compile=True)
@ -106,21 +118,35 @@ def make_dp_model_class(cls):
def _compute_per_example_grads(self, data):
x, y = data
with tf.GradientTape() as tape:
# We need to add the extra dimension to x and y because model
# expects batched input.
y_pred = self(x[None], training=True)
loss = self.compiled_loss(
y[None], y_pred, regularization_losses=self.losses)
y_pred = self(x, training=True)
loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
grads_list = tape.gradient(loss, self.trainable_variables)
clipped_grads = self._process_per_example_grads(grads_list)
return tf.squeeze(y_pred, axis=0), loss, clipped_grads
return y_pred, loss, clipped_grads
def train_step(self, data):
"""DP-SGD version of base class method."""
_, y = data
batch_size = y.shape[0]
if self._num_microbatches is None:
self._num_microbatches = batch_size
if batch_size % self._num_microbatches != 0:
raise ValueError('Number of_microbatches must divide batch size.')
def reshape_fn(x):
new_shape = (self._num_microbatches,
batch_size // self._num_microbatches) + x.shape[1:]
return tf.reshape(x, new_shape)
data = tf.nest.map_structure(reshape_fn, data)
y_pred, _, per_eg_grads = tf.vectorized_map(
self._compute_per_example_grads, data)
y_pred = tf.reshape(y_pred, (batch_size) + y_pred.shape[2:])
grads = tf.nest.map_structure(self._reduce_per_example_grads,
per_eg_grads)
self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

View file

@ -92,11 +92,74 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(model_weights[0], expected_weights)
self.assertAllClose(model_weights[1], expected_bias)
def _compute_expected_gradients(self, data, labels, w, l2_norm_clip,
num_microbatches):
batch_size = data.shape[0]
if num_microbatches is None:
num_microbatches = batch_size
preds = np.matmul(data, w)
grads = 2 * data * (labels - preds)[:, np.newaxis]
grads = np.reshape(grads,
[num_microbatches, batch_size // num_microbatches, -1])
mb_grads = np.mean(grads, axis=1)
mb_grad_norms = np.linalg.norm(mb_grads, axis=1)
scale = np.minimum(l2_norm_clip / mb_grad_norms, 1.0)
mb_grads = mb_grads * scale[:, np.newaxis]
final_grads = np.mean(mb_grads, axis=0)
return final_grads
@parameterized.named_parameters(
('noise_multiplier 3 2', 3.0, 2.0),
('noise_multiplier 5 4', 5.0, 4.0),
('mb_test 0', 1.0, None),
('mb_test 1', 1.0, 1),
('mb_test 2', 1.0, 2),
('mb_test 4', 1.0, 4),
)
def testNoiseMultiplier(self, l2_norm_clip, noise_multiplier):
def testMicrobatches(self, l2_norm_clip, num_microbatches):
train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
w = np.zeros((2))
train_labels = np.array([1.0, 3.0, -2.0, -4.0])
learning_rate = 1.0
expected_grads = self._compute_expected_gradients(train_data, train_labels,
w, l2_norm_clip,
num_microbatches)
expected_weights = np.squeeze(learning_rate * expected_grads)
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss = tf.keras.losses.MeanSquaredError()
# Simple linear model returns w * x + b.
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=0.0,
num_microbatches=num_microbatches,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
1, use_bias=False, kernel_initializer='zeros')
])
model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
model_weights = np.squeeze(model.get_weights())
self.assertAllClose(model_weights, expected_weights)
@parameterized.named_parameters(
('noise_multiplier 3 2 1', 3.0, 2.0, 1),
('noise_multiplier 5 4 1', 5.0, 4.0, 1),
('noise_multiplier 3 2 2', 3.0, 2.0, 2),
('noise_multiplier 5 4 2', 5.0, 4.0, 2),
('noise_multiplier 3 2 4', 3.0, 2.0, 4),
('noise_multiplier 5 4 4', 5.0, 4.0, 4),
)
def testNoiseMultiplier(self, l2_norm_clip, noise_multiplier,
num_microbatches):
# The idea behind this test is to start with a model whose parameters
# are set to zero. We then run one step of a model that produces
# an un-noised gradient of zero, and then compute the standard deviation
@ -104,8 +167,8 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
# deviation.
# Data is one example of length 1000, set to zero, with label zero.
train_data = np.zeros((1, 1000))
train_labels = np.array([0.0])
train_data = np.zeros((4, 1000))
train_labels = np.array([0.0, 0.0, 0.0, 0.0])
learning_rate = 1.0
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
@ -115,21 +178,67 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
model = dp_keras_model.DPSequential(
l2_norm_clip=l2_norm_clip,
noise_multiplier=noise_multiplier,
num_microbatches=num_microbatches,
layers=[
tf.keras.layers.InputLayer(input_shape=(1000,)),
tf.keras.layers.Dense(
1, kernel_initializer='zeros', bias_initializer='zeros')
])
model.compile(optimizer=optimizer, loss=loss)
model.fit(train_data, train_labels, epochs=1, batch_size=1)
model.fit(train_data, train_labels, epochs=1, batch_size=4)
model_weights = model.get_weights()
measured_std = np.std(model_weights[0])
expected_std = l2_norm_clip * noise_multiplier
expected_std = l2_norm_clip * noise_multiplier / num_microbatches
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
self.assertNear(measured_std, expected_std, 0.1 * expected_std)
# Simple check to make sure dimensions are correct when output has
# dimension > 1.
@parameterized.named_parameters(
('mb_test None 1', None, 1),
('mb_test 1 2', 1, 2),
('mb_test 2 2', 2, 2),
('mb_test 4 4', 4, 4),
)
def testMultiDimensionalOutput(self, num_microbatches, output_dimension):
train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
train_labels = np.array([0, 1, 1, 0])
learning_rate = 1.0
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model = dp_keras_model.DPSequential(
l2_norm_clip=1.0e9,
noise_multiplier=0.0,
num_microbatches=num_microbatches,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
output_dimension, use_bias=False, kernel_initializer='zeros')
])
model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
# Checks that calls to earlier API using `use_xla` as a positional argument
# raise an exception.
@parameterized.named_parameters(
('earlier API True', True),
('earlier API False', False),
)
def testEarlierAPIFails(self, use_xla):
with self.assertRaises(ValueError):
_ = dp_keras_model.DPSequential(
1.0e9,
0.0,
use_xla,
layers=[
tf.keras.layers.InputLayer(input_shape=(2,)),
tf.keras.layers.Dense(
2, use_bias=False, kernel_initializer='zeros')
])
if __name__ == '__main__':
tf.test.main()