forked from 626_privacy/tensorflow_privacy
Clarify logic in Keras version of DP-SGD optimizer, and add a unit test involving clipping on multiple variables.
PiperOrigin-RevId: 472559697
This commit is contained in:
parent 628e5bb926
commit 407e5c8e11
2 changed files with 49 additions and 1 deletion
@@ -271,10 +271,12 @@ def make_keras_optimizer_class(cls):
       jacobian = tape.jacobian(
           microbatch_losses, var_list, unconnected_gradients='zero')
 
-      # Clip gradients to given l2_norm_clip.
       def clip_gradients(g):
+        """Clips gradients to given l2_norm_clip."""
         return tf.clip_by_global_norm(g, self._l2_norm_clip)[0]
 
+      # Clip all gradients. Note that `tf.map_fn` applies the given function
+      # to its arguments unstacked along axis 0.
       clipped_gradients = tf.map_fn(clip_gradients, jacobian)
 
       def reduce_noise_normalize_batch(g):
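The added comment is the crux of the clarified logic: `tape.jacobian` returns one per-microbatch Jacobian per variable, `tf.map_fn` unstacks each of those tensors along axis 0, and `clip_gradients` therefore sees one microbatch's gradient list at a time, which `tf.clip_by_global_norm` clips jointly across all variables. A minimal standalone sketch of that behavior, with hypothetical shapes and values (not part of the commit):

```python
import tensorflow as tf

# Illustrative stand-ins for `jacobian`: one per-microbatch gradient tensor
# per variable, with axis 0 indexing 3 microbatches (shapes are made up).
jac0 = tf.constant([[3.0, 4.0], [0.3, 0.4], [6.0, 8.0]])  # var0: shape (3, 2)
jac1 = tf.constant([[12.0], [0.0], [0.0]])                # var1: shape (3, 1)
l2_norm_clip = 1.0

def clip_gradients(g):
  # `g` is a list holding one microbatch's slice of each Jacobian; the
  # global norm is taken over the whole list, coupling the two variables.
  return tf.clip_by_global_norm(g, l2_norm_clip)[0]

# tf.map_fn unstacks every tensor in the list along axis 0 and re-stacks
# the clipped results, yielding tensors of the original shapes.
clipped = tf.map_fn(clip_gradients, [jac0, jac1])

# Microbatch 0 has global norm sqrt(3^2 + 4^2 + 12^2) = 13, so both of its
# slices are scaled by 1/13; microbatch 1 (norm 0.5) passes through unchanged.
print(clipped[0].numpy(), clipped[1].numpy())
```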
@@ -133,6 +133,52 @@ class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
     grads_and_vars = opt._compute_gradients(loss, [var0])
     self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])
 
+  @parameterized.named_parameters(
+      ('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.5, 1),
+      ('DPGradientDescent 2', dp_optimizer_keras.DPKerasSGDOptimizer, 2.5, 2),
+      ('DPGradientDescent 4', dp_optimizer_keras.DPKerasSGDOptimizer, 2.5, 4),
+      ('DPGradientDescentVectorized',
+       dp_optimizer_keras_vectorized.VectorizedDPKerasSGDOptimizer, 2.5, 1),
+  )
+  def testClippingNormMultipleVariables(self, cls, l2_clip_norm,
+                                        num_microbatches):
+    var0 = tf.Variable([1.0, 2.0])
+    var1 = tf.Variable([3.0])
+    data0 = tf.Variable([[3.0, 6.0], [5.0, 6.0], [4.0, 8.0], [-1.0, 0.0]])
+    data1 = tf.Variable([[8.0], [2.0], [3.0], [1.0]])
+
+    opt = cls(
+        l2_norm_clip=l2_clip_norm,
+        noise_multiplier=0.0,
+        num_microbatches=num_microbatches,
+        learning_rate=2.0)
+
+    loss = lambda: self._loss(data0, var0) + self._loss(data1, var1)
+
+    # Expected gradient is sum of differences.
+    grads_and_vars = opt._compute_gradients(loss, [var0, var1])
+
+    # Compute expected gradients.
+    batch_size = data0.shape[0]
+    grad0 = (data0 - var0).numpy()
+    grad1 = (data1 - var1).numpy()
+    grads = np.concatenate([grad0, grad1], axis=1)
+
+    grads = np.reshape(
+        grads, (num_microbatches, int(batch_size / num_microbatches), -1))
+    grads = np.mean(grads, axis=1)
+
+    norms = np.apply_along_axis(np.linalg.norm, axis=1, arr=grads)
+    grad_factors = l2_clip_norm / np.maximum(l2_clip_norm, norms)
+
+    scaled_grads = grads * grad_factors[:, None]
+    mean_scaled_grads = -np.mean(scaled_grads, axis=0)
+    expected0, expected1 = np.split(mean_scaled_grads, [2], axis=0)
+
+    # Compare expected with actual gradients.
+    self.assertAllCloseAccordingToType(expected0, grads_and_vars[0][0])
+    self.assertAllCloseAccordingToType(expected1, grads_and_vars[1][0])
+
   @parameterized.named_parameters(
       ('DPGradientDescent 2 4 1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0,
        4.0, 1),
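For concreteness, the test's expected-gradient arithmetic can be worked through by hand for the single-example-per-microbatch parameterizations (num_microbatches=4, l2_clip_norm=2.5). The test's `_loss` has gradient `var - data` per example; the test computes `data - var` and negates at the end. A standalone NumPy sketch of the same computation, mirroring the test's values:

```python
import numpy as np

l2_clip_norm = 2.5

# Per-example (data - var) gradients for var0 and var1, using the test's values.
grad0 = np.array([[3.0, 6.0], [5.0, 6.0], [4.0, 8.0], [-1.0, 0.0]]) - [1.0, 2.0]
grad1 = np.array([[8.0], [2.0], [3.0], [1.0]]) - [3.0]

# Each microbatch's global norm spans both variables jointly, which is
# exactly what the multiple-variable test is exercising.
grads = np.concatenate([grad0, grad1], axis=1)   # shape (4, 3)
norms = np.linalg.norm(grads, axis=1)            # ~[6.71, 5.74, 6.71, 3.46]
factors = l2_clip_norm / np.maximum(l2_clip_norm, norms)

# All four norms exceed 2.5, so every microbatch is rescaled to norm 2.5;
# averaging and negating gives the gradients the optimizer should report.
mean_scaled = -np.mean(grads * factors[:, None], axis=0)
expected0, expected1 = mean_scaled[:2], mean_scaled[2:]
print(expected0, expected1)
```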