forked from 626_privacy/tensorflow_privacy
Add ability to specify number of microbatches in DPModel class.
PiperOrigin-RevId: 430358084
This commit is contained in:
parent bfdcb7f64f
commit a33afde0c1
2 changed files with 157 additions and 22 deletions
@@ -62,25 +62,37 @@ def make_dp_model_class(cls):
         self,
         l2_norm_clip,
         noise_multiplier,
+        num_microbatches=None,
         use_xla=True,
         *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
         **kwargs):
       """Initializes the DPModelClass.
 
       Args:
         l2_norm_clip: Clipping norm (max L2 norm of per microbatch
          gradients).
        noise_multiplier: Ratio of the standard deviation to the clipping
          norm.
+        num_microbatches: Number of microbatches.
        use_xla: If `True`, compiles train_step to XLA.
        *args: These will be passed on to the base class `__init__` method.
        **kwargs: These will be passed on to the base class `__init__`
          method.
       """
       super().__init__(*args, **kwargs)
       self._l2_norm_clip = l2_norm_clip
       self._noise_multiplier = noise_multiplier
 
+      # Given that `num_microbatches` was added as an argument after the fact,
+      # this check helps detect unintended calls to the earlier API.
+      # In particular, boolean values supplied to `use_xla` in the earlier API
+      # will raise an error.
+      if isinstance(num_microbatches, bool):
+        raise ValueError('Boolean value supplied for `num_microbatches`. '
+                         'Did you intend it for `use_xla`?')
+
+      self._num_microbatches = num_microbatches
+
       if use_xla:
         self.train_step = tf.function(
             self.train_step, experimental_compile=True)
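A minimal usage sketch (not part of the commit) of the new constructor argument, assuming the import path `tensorflow_privacy.privacy.keras_models.dp_keras_model` that the tests below rely on. Because `num_microbatches` is inserted before `use_xla`, an old-style positional boolean would otherwise land in the new slot, which is what the guard above catches.

import tensorflow as tf
from tensorflow_privacy.privacy.keras_models import dp_keras_model

# New keyword: split each training batch into 2 microbatches for clipping.
# Leaving it as None keeps the old behaviour of one microbatch per example.
model = dp_keras_model.DPSequential(
    l2_norm_clip=1.0,
    noise_multiplier=0.5,
    num_microbatches=2,
    layers=[
        tf.keras.layers.InputLayer(input_shape=(2,)),
        tf.keras.layers.Dense(1)
    ])

# The earlier positional call DPSequential(1.0, 0.5, True, layers=[...]) now
# raises ValueError instead of silently treating True as num_microbatches.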
@@ -106,21 +118,35 @@ def make_dp_model_class(cls):
     def _compute_per_example_grads(self, data):
       x, y = data
       with tf.GradientTape() as tape:
-        # We need to add the extra dimension to x and y because model
-        # expects batched input.
-        y_pred = self(x[None], training=True)
-        loss = self.compiled_loss(
-            y[None], y_pred, regularization_losses=self.losses)
+        y_pred = self(x, training=True)
+        loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
 
       grads_list = tape.gradient(loss, self.trainable_variables)
       clipped_grads = self._process_per_example_grads(grads_list)
-      return tf.squeeze(y_pred, axis=0), loss, clipped_grads
+      return y_pred, loss, clipped_grads
 
     def train_step(self, data):
       """DP-SGD version of base class method."""
       _, y = data
+      batch_size = y.shape[0]
+
+      if self._num_microbatches is None:
+        self._num_microbatches = batch_size
+      if batch_size % self._num_microbatches != 0:
+        raise ValueError('Number of microbatches must divide batch size.')
+
+      def reshape_fn(x):
+        new_shape = (self._num_microbatches,
+                     batch_size // self._num_microbatches) + x.shape[1:]
+        return tf.reshape(x, new_shape)
+
+      data = tf.nest.map_structure(reshape_fn, data)
+
       y_pred, _, per_eg_grads = tf.vectorized_map(
           self._compute_per_example_grads, data)
+
+      y_pred = tf.reshape(y_pred, (batch_size,) + y_pred.shape[2:])
+
       grads = tf.nest.map_structure(self._reduce_per_example_grads,
                                     per_eg_grads)
       self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
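A small standalone sketch (not part of the commit) of what `reshape_fn` does to an incoming batch, assuming 4 examples with 2 features and `num_microbatches=2`. After the reshape, `tf.vectorized_map` invokes `_compute_per_example_grads` once per microbatch rather than once per example, which is why the `x[None]` / `tf.squeeze` handling was removed above.

import tensorflow as tf

x = tf.reshape(tf.range(8.0), (4, 2))   # batch of 4 examples, 2 features each
num_microbatches = 2
batch_size = x.shape[0]
new_shape = (num_microbatches, batch_size // num_microbatches) + x.shape[1:]
x_mb = tf.reshape(x, new_shape)
print(x_mb.shape)                        # (2, 2, 2): 2 microbatches of 2 examples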
@@ -92,11 +92,74 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
     self.assertAllClose(model_weights[0], expected_weights)
     self.assertAllClose(model_weights[1], expected_bias)
 
+  def _compute_expected_gradients(self, data, labels, w, l2_norm_clip,
+                                  num_microbatches):
+    batch_size = data.shape[0]
+    if num_microbatches is None:
+      num_microbatches = batch_size
+
+    preds = np.matmul(data, w)
+
+    grads = 2 * data * (labels - preds)[:, np.newaxis]
+    grads = np.reshape(grads,
+                       [num_microbatches, batch_size // num_microbatches, -1])
+
+    mb_grads = np.mean(grads, axis=1)
+    mb_grad_norms = np.linalg.norm(mb_grads, axis=1)
+
+    scale = np.minimum(l2_norm_clip / mb_grad_norms, 1.0)
+
+    mb_grads = mb_grads * scale[:, np.newaxis]
+
+    final_grads = np.mean(mb_grads, axis=0)
+    return final_grads
+
   @parameterized.named_parameters(
-      ('noise_multiplier 3 2', 3.0, 2.0),
-      ('noise_multiplier 5 4', 5.0, 4.0),
+      ('mb_test 0', 1.0, None),
+      ('mb_test 1', 1.0, 1),
+      ('mb_test 2', 1.0, 2),
+      ('mb_test 4', 1.0, 4),
   )
-  def testNoiseMultiplier(self, l2_norm_clip, noise_multiplier):
+  def testMicrobatches(self, l2_norm_clip, num_microbatches):
+    train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
+    w = np.zeros((2))
+    train_labels = np.array([1.0, 3.0, -2.0, -4.0])
+    learning_rate = 1.0
+
+    expected_grads = self._compute_expected_gradients(train_data, train_labels,
+                                                      w, l2_norm_clip,
+                                                      num_microbatches)
+    expected_weights = np.squeeze(learning_rate * expected_grads)
+
+    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+    loss = tf.keras.losses.MeanSquaredError()
+
+    # Simple linear model returns w * x + b.
+    model = dp_keras_model.DPSequential(
+        l2_norm_clip=l2_norm_clip,
+        noise_multiplier=0.0,
+        num_microbatches=num_microbatches,
+        layers=[
+            tf.keras.layers.InputLayer(input_shape=(2,)),
+            tf.keras.layers.Dense(
+                1, use_bias=False, kernel_initializer='zeros')
+        ])
+    model.compile(optimizer=optimizer, loss=loss)
+    model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
+
+    model_weights = np.squeeze(model.get_weights())
+    self.assertAllClose(model_weights, expected_weights)
+
+  @parameterized.named_parameters(
+      ('noise_multiplier 3 2 1', 3.0, 2.0, 1),
+      ('noise_multiplier 5 4 1', 5.0, 4.0, 1),
+      ('noise_multiplier 3 2 2', 3.0, 2.0, 2),
+      ('noise_multiplier 5 4 2', 5.0, 4.0, 2),
+      ('noise_multiplier 3 2 4', 3.0, 2.0, 4),
+      ('noise_multiplier 5 4 4', 5.0, 4.0, 4),
+  )
+  def testNoiseMultiplier(self, l2_norm_clip, noise_multiplier,
+                          num_microbatches):
     # The idea behind this test is to start with a model whose parameters
     # are set to zero. We then run one step of a model that produces
     # an un-noised gradient of zero, and then compute the standard deviation
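A self-contained NumPy sketch (not part of the commit) of the math the new `_compute_expected_gradients` helper encodes, using the same data as `testMicrobatches`: per-example gradients are averaged within each microbatch, each microbatch average is clipped to `l2_norm_clip`, and the clipped averages are averaged across microbatches.

import numpy as np

data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
labels = np.array([1.0, 3.0, -2.0, -4.0])
w = np.zeros(2)
l2_norm_clip, num_microbatches = 1.0, 2

per_example = 2 * data * (labels - data @ w)[:, np.newaxis]      # (4, 2)
mb = per_example.reshape(num_microbatches, -1, 2).mean(axis=1)   # (2, 2)
scale = np.minimum(l2_norm_clip / np.linalg.norm(mb, axis=1), 1.0)
print((mb * scale[:, np.newaxis]).mean(axis=0))  # roughly [-0.023, 0.019]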
@@ -104,8 +167,8 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
     # deviation.
 
     # Data is one example of length 1000, set to zero, with label zero.
-    train_data = np.zeros((1, 1000))
-    train_labels = np.array([0.0])
+    train_data = np.zeros((4, 1000))
+    train_labels = np.array([0.0, 0.0, 0.0, 0.0])
 
     learning_rate = 1.0
     optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
@@ -115,21 +178,67 @@ class DPKerasModelTest(tf.test.TestCase, parameterized.TestCase):
     model = dp_keras_model.DPSequential(
         l2_norm_clip=l2_norm_clip,
         noise_multiplier=noise_multiplier,
+        num_microbatches=num_microbatches,
         layers=[
             tf.keras.layers.InputLayer(input_shape=(1000,)),
             tf.keras.layers.Dense(
                 1, kernel_initializer='zeros', bias_initializer='zeros')
         ])
     model.compile(optimizer=optimizer, loss=loss)
-    model.fit(train_data, train_labels, epochs=1, batch_size=1)
+    model.fit(train_data, train_labels, epochs=1, batch_size=4)
 
     model_weights = model.get_weights()
     measured_std = np.std(model_weights[0])
-    expected_std = l2_norm_clip * noise_multiplier
+    expected_std = l2_norm_clip * noise_multiplier / num_microbatches
 
     # Test standard deviation is close to l2_norm_clip * noise_multiplier.
     self.assertNear(measured_std, expected_std, 0.1 * expected_std)
 
+  # Simple check to make sure dimensions are correct when output has
+  # dimension > 1.
+  @parameterized.named_parameters(
+      ('mb_test None 1', None, 1),
+      ('mb_test 1 2', 1, 2),
+      ('mb_test 2 2', 2, 2),
+      ('mb_test 4 4', 4, 4),
+  )
+  def testMultiDimensionalOutput(self, num_microbatches, output_dimension):
+    train_data = np.array([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0], [8.0, 9.0]])
+    train_labels = np.array([0, 1, 1, 0])
+    learning_rate = 1.0
+
+    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
+    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+
+    model = dp_keras_model.DPSequential(
+        l2_norm_clip=1.0e9,
+        noise_multiplier=0.0,
+        num_microbatches=num_microbatches,
+        layers=[
+            tf.keras.layers.InputLayer(input_shape=(2,)),
+            tf.keras.layers.Dense(
+                output_dimension, use_bias=False, kernel_initializer='zeros')
+        ])
+    model.compile(optimizer=optimizer, loss=loss_fn)
+    model.fit(train_data, train_labels, epochs=1, batch_size=4, shuffle=False)
+
+  # Checks that calls to earlier API using `use_xla` as a positional argument
+  # raise an exception.
+  @parameterized.named_parameters(
+      ('earlier API True', True),
+      ('earlier API False', False),
+  )
+  def testEarlierAPIFails(self, use_xla):
+    with self.assertRaises(ValueError):
+      _ = dp_keras_model.DPSequential(
+          1.0e9,
+          0.0,
+          use_xla,
+          layers=[
+              tf.keras.layers.InputLayer(input_shape=(2,)),
+              tf.keras.layers.Dense(
+                  2, use_bias=False, kernel_initializer='zeros')
+          ])
+
 
 if __name__ == '__main__':
   tf.test.main()
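On the new `expected_std = l2_norm_clip * noise_multiplier / num_microbatches` line in the last hunk: as far as the test implies, Gaussian noise with standard deviation `l2_norm_clip * noise_multiplier` is added once to the sum of clipped microbatch gradients inside `_reduce_per_example_grads` (not shown in this diff), and the result is divided by the number of microbatches, so the per-weight noise std shrinks by that factor. A quick NumPy sanity check of that scaling (an illustration, not the library code):

import numpy as np

rng = np.random.default_rng(0)
l2_norm_clip, noise_multiplier, num_microbatches = 3.0, 2.0, 4
noise = rng.normal(0.0, l2_norm_clip * noise_multiplier, size=100_000)
print(np.std(noise / num_microbatches))  # close to 3.0 * 2.0 / 4 = 1.5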