forked from 626_privacy/tensorflow_privacy
Adds a Keras optimizer version of DP-SGD. New optimizers are subclasses of tf.keras.optimizers.Optimizer and override both _compute_gradients and get_gradients.
PiperOrigin-RevId: 325124698
This commit is contained in:
parent
191f2461c5
commit
e91c820b2a
2 changed files with 418 additions and 0 deletions
161
tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py
Normal file
161
tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py
Normal file
|
@ -0,0 +1,161 @@
|
|||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Differentially private version of Keras optimizer v2."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.dp_query import gaussian_query
|
||||
|
||||
|
||||
def make_keras_optimizer_class(cls):
|
||||
"""Constructs a DP Keras optimizer class from an existing one."""
|
||||
|
||||
class DPOptimizerClass(cls):
|
||||
"""Differentially private subclass of given class cls.
|
||||
|
||||
The class tf.keras.optimizers.Optimizer has two methods to compute
|
||||
gradients, `_compute_gradients` and `get_gradients`. The first works
|
||||
with eager execution, while the second runs in graph mode and is used
|
||||
by canned estimators.
|
||||
|
||||
Internally, DPOptimizerClass stores hyperparameters both individually
|
||||
and encapsulated in a `GaussianSumQuery` object for these two use cases.
|
||||
However, this should be invisible to users of this class.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
l2_norm_clip,
|
||||
noise_multiplier,
|
||||
num_microbatches=None,
|
||||
*args, # pylint: disable=keyword-arg-before-vararg, g-doc-args
|
||||
**kwargs):
|
||||
"""Initialize the DPOptimizerClass.
|
||||
|
||||
Args:
|
||||
l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients)
|
||||
noise_multiplier: Ratio of the standard deviation to the clipping norm
|
||||
num_microbatches: The number of microbatches into which each minibatch
|
||||
is split.
|
||||
"""
|
||||
super(DPOptimizerClass, self).__init__(*args, **kwargs)
|
||||
self._l2_norm_clip = l2_norm_clip
|
||||
self._noise_multiplier = noise_multiplier
|
||||
self._num_microbatches = num_microbatches
|
||||
self._dp_sum_query = gaussian_query.GaussianSumQuery(
|
||||
l2_norm_clip, l2_norm_clip * noise_multiplier)
|
||||
self._global_state = self._dp_sum_query.initial_global_state()
|
||||
|
||||
def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
|
||||
"""DP version of superclass method."""
|
||||
|
||||
# Compute loss.
|
||||
if not callable(loss) and tape is None:
|
||||
raise ValueError('`tape` is required when a `Tensor` loss is passed.')
|
||||
tape = tape if tape is not None else tf.GradientTape()
|
||||
|
||||
if callable(loss):
|
||||
with tape:
|
||||
if not callable(var_list):
|
||||
tape.watch(var_list)
|
||||
|
||||
if callable(loss):
|
||||
loss = loss()
|
||||
microbatch_losses = tf.reduce_mean(
|
||||
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
|
||||
|
||||
if callable(var_list):
|
||||
var_list = var_list()
|
||||
else:
|
||||
microbatch_losses = tf.reduce_mean(
|
||||
tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
|
||||
|
||||
var_list = tf.nest.flatten(var_list)
|
||||
|
||||
# Compute the per-microbatch losses using helpful jacobian method.
|
||||
with tf.keras.backend.name_scope(self._name + '/gradients'):
|
||||
jacobian = tape.jacobian(microbatch_losses, var_list)
|
||||
|
||||
# Clip gradients to given l2_norm_clip.
|
||||
def clip_gradients(g):
|
||||
return tf.clip_by_global_norm(g, self._l2_norm_clip)[0]
|
||||
|
||||
clipped_gradients = tf.map_fn(clip_gradients, jacobian)
|
||||
|
||||
def reduce_noise_normalize_batch(g):
|
||||
# Sum gradients over all microbatches.
|
||||
summed_gradient = tf.reduce_sum(g, axis=0)
|
||||
|
||||
# Add noise to summed gradients.
|
||||
noise_stddev = self._l2_norm_clip * self._noise_multiplier
|
||||
noise = tf.random.normal(
|
||||
tf.shape(input=summed_gradient), stddev=noise_stddev)
|
||||
noised_gradient = tf.add(summed_gradient, noise)
|
||||
|
||||
# Normalize by number of microbatches and return.
|
||||
return tf.truediv(noised_gradient, self._num_microbatches)
|
||||
|
||||
final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch,
|
||||
clipped_gradients)
|
||||
|
||||
return list(zip(final_gradients, var_list))
|
||||
|
||||
def get_gradients(self, loss, params):
|
||||
"""DP version of superclass method."""
|
||||
|
||||
# This code mostly follows the logic in the original DPOptimizerClass
|
||||
# in dp_optimizer.py, except that this returns only the gradients,
|
||||
# not the gradients and variables.
|
||||
microbatch_losses = tf.reshape(loss, [self._num_microbatches, -1])
|
||||
sample_params = (
|
||||
self._dp_sum_query.derive_sample_params(self._global_state))
|
||||
|
||||
def process_microbatch(i, sample_state):
|
||||
"""Process one microbatch (record) with privacy helper."""
|
||||
mean_loss = tf.reduce_mean(
|
||||
input_tensor=tf.gather(microbatch_losses, [i]))
|
||||
grads = tf.gradients(mean_loss, params)
|
||||
sample_state = self._dp_sum_query.accumulate_record(
|
||||
sample_params, sample_state, grads)
|
||||
return sample_state
|
||||
|
||||
sample_state = self._dp_sum_query.initial_sample_state(params)
|
||||
for idx in range(self._num_microbatches):
|
||||
sample_state = process_microbatch(idx, sample_state)
|
||||
grad_sums, self._global_state = (
|
||||
self._dp_sum_query.get_noised_result(sample_state,
|
||||
self._global_state))
|
||||
|
||||
def normalize(v):
|
||||
try:
|
||||
return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))
|
||||
except TypeError:
|
||||
return None
|
||||
|
||||
final_grads = tf.nest.map_structure(normalize, grad_sums)
|
||||
|
||||
return final_grads
|
||||
|
||||
return DPOptimizerClass
|
||||
|
||||
|
||||
DPKerasAdagradOptimizer = make_keras_optimizer_class(
|
||||
tf.keras.optimizers.Adagrad)
|
||||
DPKerasAdamOptimizer = make_keras_optimizer_class(tf.keras.optimizers.Adam)
|
||||
DPKerasSGDOptimizer = make_keras_optimizer_class(tf.keras.optimizers.SGD)
|
257
tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py
Normal file
257
tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py
Normal file
|
@ -0,0 +1,257 @@
|
|||
# Copyright 2019, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for dp_optimizer_keras.py."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras
|
||||
|
||||
|
||||
class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||
"""Tests for _compute_gradients method."""
|
||||
|
||||
def _loss(self, val0, val1):
|
||||
"""Loss function whose derivative w.r.t val1 is val1 - val0."""
|
||||
return 0.5 * tf.reduce_sum(
|
||||
input_tensor=tf.math.squared_difference(val0, val1), axis=1)
|
||||
|
||||
# Parameters for testing: optimizer, num_microbatches, expected gradient for
|
||||
# var0, expected gradient for var1.
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1,
|
||||
[-2.5, -2.5], [-0.5]),
|
||||
('DPAdam 2', dp_optimizer_keras.DPKerasAdamOptimizer, 2,
|
||||
[-2.5, -2.5], [-0.5]),
|
||||
('DPAdagrad 4', dp_optimizer_keras.DPKerasAdagradOptimizer, 4,
|
||||
[-2.5, -2.5], [-0.5]),
|
||||
)
|
||||
def testBaseline(self, cls, num_microbatches, expected_grad0, expected_grad1):
|
||||
var0 = tf.Variable([1.0, 2.0])
|
||||
var1 = tf.Variable([3.0])
|
||||
data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]])
|
||||
data1 = tf.Variable([[8.0], [2.0], [3.0], [1.0]])
|
||||
|
||||
opt = cls(
|
||||
l2_norm_clip=100.0,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=num_microbatches,
|
||||
learning_rate=2.0)
|
||||
|
||||
loss = lambda: self._loss(data0, var0) + self._loss(data1, var1)
|
||||
|
||||
grads_and_vars = opt._compute_gradients(loss, [var0, var1])
|
||||
self.assertAllCloseAccordingToType(expected_grad0, grads_and_vars[0][0])
|
||||
self.assertAllCloseAccordingToType(expected_grad1, grads_and_vars[1][0])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent', dp_optimizer_keras.DPKerasSGDOptimizer),)
|
||||
def testClippingNorm(self, cls):
|
||||
var0 = tf.Variable([0.0, 0.0])
|
||||
data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]])
|
||||
|
||||
opt = cls(
|
||||
l2_norm_clip=1.0,
|
||||
noise_multiplier=0.0,
|
||||
num_microbatches=1,
|
||||
learning_rate=2.0)
|
||||
|
||||
loss = lambda: self._loss(data0, var0)
|
||||
# Expected gradient is sum of differences.
|
||||
grads_and_vars = opt._compute_gradients(loss, [var0])
|
||||
self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent 2 4 1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0,
|
||||
4.0, 1),
|
||||
('DPGradientDescent 4 1 4', dp_optimizer_keras.DPKerasSGDOptimizer, 4.0,
|
||||
1.0, 4),
|
||||
)
|
||||
def testNoiseMultiplier(self, cls, l2_norm_clip, noise_multiplier,
|
||||
num_microbatches):
|
||||
var0 = tf.Variable(tf.zeros([1000], dtype=tf.float32))
|
||||
data0 = tf.Variable(tf.zeros([16, 1000], dtype=tf.float32))
|
||||
|
||||
opt = cls(
|
||||
l2_norm_clip=l2_norm_clip,
|
||||
noise_multiplier=noise_multiplier,
|
||||
num_microbatches=num_microbatches,
|
||||
learning_rate=2.0)
|
||||
|
||||
loss = lambda: self._loss(data0, var0)
|
||||
grads_and_vars = opt._compute_gradients(loss, [var0])
|
||||
grads = grads_and_vars[0][0].numpy()
|
||||
|
||||
# Test standard deviation is close to l2_norm_clip * noise_multiplier.
|
||||
self.assertNear(
|
||||
np.std(grads), l2_norm_clip * noise_multiplier / num_microbatches, 0.5)
|
||||
|
||||
|
||||
class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase):
|
||||
"""Tests for get_gradient method.
|
||||
|
||||
Since get_gradients must run in graph mode, the method is tested within
|
||||
the Estimator framework.
|
||||
"""
|
||||
|
||||
def _make_linear_model_fn(self, opt_cls, l2_norm_clip, noise_multiplier,
|
||||
num_microbatches, learning_rate):
|
||||
"""Returns a model function for a linear regressor."""
|
||||
|
||||
def linear_model_fn(features, labels, mode):
|
||||
layer = tf.keras.layers.Dense(
|
||||
1,
|
||||
activation='linear',
|
||||
name='dense',
|
||||
kernel_initializer='zeros',
|
||||
bias_initializer='zeros')
|
||||
preds = layer.apply(features)
|
||||
|
||||
vector_loss = 0.5 * tf.math.squared_difference(labels, preds)
|
||||
scalar_loss = tf.reduce_mean(input_tensor=vector_loss)
|
||||
|
||||
optimizer = opt_cls(
|
||||
l2_norm_clip=l2_norm_clip,
|
||||
noise_multiplier=noise_multiplier,
|
||||
num_microbatches=num_microbatches,
|
||||
learning_rate=learning_rate)
|
||||
|
||||
params = layer.trainable_weights
|
||||
global_step = tf.compat.v1.train.get_global_step()
|
||||
train_op = tf.group(
|
||||
optimizer.get_updates(loss=vector_loss, params=params),
|
||||
[tf.compat.v1.assign_add(global_step, 1)])
|
||||
return tf.estimator.EstimatorSpec(
|
||||
mode=mode, loss=scalar_loss, train_op=train_op)
|
||||
|
||||
return linear_model_fn
|
||||
|
||||
# Parameters for testing: optimizer, num_microbatches.
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),
|
||||
('DPGradientDescent 2', dp_optimizer_keras.DPKerasSGDOptimizer, 2),
|
||||
('DPGradientDescent 4', dp_optimizer_keras.DPKerasSGDOptimizer, 4),)
|
||||
def testBaseline(self, cls, num_microbatches):
|
||||
"""Tests that DP optimizers work with tf.estimator."""
|
||||
|
||||
linear_regressor = tf.estimator.Estimator(
|
||||
model_fn=self._make_linear_model_fn(cls, 100.0, 0.0, num_microbatches,
|
||||
0.05))
|
||||
|
||||
true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32)
|
||||
true_bias = np.array([6.0]).astype(np.float32)
|
||||
train_data = np.random.normal(scale=3.0, size=(1000, 4)).astype(np.float32)
|
||||
|
||||
train_labels = np.matmul(train_data,
|
||||
true_weights) + true_bias + np.random.normal(
|
||||
scale=0.0, size=(1000, 1)).astype(np.float32)
|
||||
|
||||
def train_input_fn():
|
||||
return tf.data.Dataset.from_tensor_slices(
|
||||
(train_data, train_labels)).batch(8)
|
||||
|
||||
linear_regressor.train(input_fn=train_input_fn, steps=125)
|
||||
|
||||
self.assertAllClose(
|
||||
linear_regressor.get_variable_value('dense/kernel'),
|
||||
true_weights,
|
||||
atol=0.05)
|
||||
self.assertAllClose(
|
||||
linear_regressor.get_variable_value('dense/bias'), true_bias, atol=0.05)
|
||||
|
||||
# Parameters for testing: optimizer, num_microbatches.
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),)
|
||||
def testClippingNorm(self, cls, num_microbatches):
|
||||
"""Tests that DP optimizers work with tf.estimator."""
|
||||
|
||||
true_weights = np.array([[6.0], [0.0], [0], [0]]).astype(np.float32)
|
||||
true_bias = np.array([0]).astype(np.float32)
|
||||
|
||||
train_data = np.array([[1.0, 0.0, 0.0, 0.0]]).astype(np.float32)
|
||||
train_labels = np.matmul(train_data, true_weights) + true_bias
|
||||
|
||||
def train_input_fn():
|
||||
return tf.data.Dataset.from_tensor_slices(
|
||||
(train_data, train_labels)).batch(1)
|
||||
|
||||
unclipped_linear_regressor = tf.estimator.Estimator(
|
||||
model_fn=self._make_linear_model_fn(cls, 1.0e9, 0.0, num_microbatches,
|
||||
1.0))
|
||||
unclipped_linear_regressor.train(input_fn=train_input_fn, steps=1)
|
||||
|
||||
kernel_value = unclipped_linear_regressor.get_variable_value('dense/kernel')
|
||||
bias_value = unclipped_linear_regressor.get_variable_value('dense/bias')
|
||||
global_norm = np.linalg.norm(np.concatenate((kernel_value, [bias_value])))
|
||||
|
||||
clipped_linear_regressor = tf.estimator.Estimator(
|
||||
model_fn=self._make_linear_model_fn(cls, 1.0, 0.0, num_microbatches,
|
||||
1.0))
|
||||
clipped_linear_regressor.train(input_fn=train_input_fn, steps=1)
|
||||
|
||||
self.assertAllClose(
|
||||
clipped_linear_regressor.get_variable_value('dense/kernel'),
|
||||
kernel_value / global_norm,
|
||||
atol=0.001)
|
||||
self.assertAllClose(
|
||||
clipped_linear_regressor.get_variable_value('dense/bias'),
|
||||
bias_value / global_norm,
|
||||
atol=0.001)
|
||||
|
||||
# Parameters for testing: optimizer, num_microbatches.
|
||||
@parameterized.named_parameters(
|
||||
('DPGradientDescent 2 4 1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0,
|
||||
4.0, 1),
|
||||
('DPGradientDescent 3 2 4', dp_optimizer_keras.DPKerasSGDOptimizer, 3.0,
|
||||
2.0, 4),
|
||||
('DPGradientDescent 8 6 8', dp_optimizer_keras.DPKerasSGDOptimizer, 8.0,
|
||||
6.0, 8),
|
||||
)
|
||||
def testNoiseMultiplier(self, cls, l2_norm_clip, noise_multiplier,
|
||||
num_microbatches):
|
||||
"""Tests that DP optimizers work with tf.estimator."""
|
||||
|
||||
linear_regressor = tf.estimator.Estimator(
|
||||
model_fn=self._make_linear_model_fn(
|
||||
cls,
|
||||
l2_norm_clip,
|
||||
noise_multiplier,
|
||||
num_microbatches,
|
||||
learning_rate=1.0))
|
||||
|
||||
true_weights = np.zeros((1000, 1), dtype=np.float32)
|
||||
true_bias = np.array([0.0]).astype(np.float32)
|
||||
|
||||
train_data = np.zeros((16, 1000), dtype=np.float32)
|
||||
train_labels = np.matmul(train_data, true_weights) + true_bias
|
||||
|
||||
def train_input_fn():
|
||||
return tf.data.Dataset.from_tensor_slices(
|
||||
(train_data, train_labels)).batch(16)
|
||||
|
||||
linear_regressor.train(input_fn=train_input_fn, steps=1)
|
||||
|
||||
kernel_value = linear_regressor.get_variable_value('dense/kernel')
|
||||
self.assertNear(
|
||||
np.std(kernel_value),
|
||||
l2_norm_clip * noise_multiplier / num_microbatches, 0.5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
Loading…
Reference in a new issue