diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py
new file mode 100644
index 0000000..2fff9c2
--- /dev/null
+++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras.py
@@ -0,0 +1,161 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Differentially private version of Keras optimizer v2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow_privacy.privacy.dp_query import gaussian_query
+
+
+def make_keras_optimizer_class(cls):
+  """Constructs a DP Keras optimizer class from an existing one."""
+
+  class DPOptimizerClass(cls):
+    """Differentially private subclass of the given class `cls`.
+
+    The class tf.keras.optimizers.Optimizer has two methods to compute
+    gradients, `_compute_gradients` and `get_gradients`. The first works
+    with eager execution, while the second runs in graph mode and is used
+    by canned estimators.
+
+    Internally, DPOptimizerClass stores hyperparameters both individually
+    and encapsulated in a `GaussianSumQuery` object for these two use cases.
+    However, this should be invisible to users of this class.
+    """
+
+    def __init__(
+        self,
+        l2_norm_clip,
+        noise_multiplier,
+        num_microbatches=None,
+        *args,  # pylint: disable=keyword-arg-before-vararg, g-doc-args
+        **kwargs):
+      """Initializes the DPOptimizerClass.
+
+      Args:
+        l2_norm_clip: Clipping norm (max L2 norm of per-microbatch gradients).
+        noise_multiplier: Ratio of the noise standard deviation to the
+          clipping norm.
+        num_microbatches: The number of microbatches into which each minibatch
+          is split.
+      """
+      super(DPOptimizerClass, self).__init__(*args, **kwargs)
+      self._l2_norm_clip = l2_norm_clip
+      self._noise_multiplier = noise_multiplier
+      self._num_microbatches = num_microbatches
+      self._dp_sum_query = gaussian_query.GaussianSumQuery(
+          l2_norm_clip, l2_norm_clip * noise_multiplier)
+      self._global_state = self._dp_sum_query.initial_global_state()
+
+    def _compute_gradients(self, loss, var_list, grad_loss=None, tape=None):
+      """DP version of superclass method."""
+
+      # Compute loss.
+      if not callable(loss) and tape is None:
+        raise ValueError('`tape` is required when a `Tensor` loss is passed.')
+      tape = tape if tape is not None else tf.GradientTape()
+
+      if callable(loss):
+        with tape:
+          if not callable(var_list):
+            tape.watch(var_list)
+
+          if callable(loss):
+            loss = loss()
+            microbatch_losses = tf.reduce_mean(
+                tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
+
+          if callable(var_list):
+            var_list = var_list()
+      else:
+        # A `Tensor` loss was passed; record the microbatch losses on the
+        # caller-provided tape so they can be differentiated below.
+        with tape:
+          microbatch_losses = tf.reduce_mean(
+              tf.reshape(loss, [self._num_microbatches, -1]), axis=1)
+
+      var_list = tf.nest.flatten(var_list)
+
+      # Compute the per-microbatch gradients using the tape's jacobian method.
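+      # For each variable v in var_list, the jacobian below has shape
+      # [num_microbatches] + v.shape, i.e. one gradient per microbatch loss;
+      # each per-microbatch gradient is clipped to l2_norm_clip individually
+      # before being summed, noised, and averaged.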
+ with tf.keras.backend.name_scope(self._name + '/gradients'): + jacobian = tape.jacobian(microbatch_losses, var_list) + + # Clip gradients to given l2_norm_clip. + def clip_gradients(g): + return tf.clip_by_global_norm(g, self._l2_norm_clip)[0] + + clipped_gradients = tf.map_fn(clip_gradients, jacobian) + + def reduce_noise_normalize_batch(g): + # Sum gradients over all microbatches. + summed_gradient = tf.reduce_sum(g, axis=0) + + # Add noise to summed gradients. + noise_stddev = self._l2_norm_clip * self._noise_multiplier + noise = tf.random.normal( + tf.shape(input=summed_gradient), stddev=noise_stddev) + noised_gradient = tf.add(summed_gradient, noise) + + # Normalize by number of microbatches and return. + return tf.truediv(noised_gradient, self._num_microbatches) + + final_gradients = tf.nest.map_structure(reduce_noise_normalize_batch, + clipped_gradients) + + return list(zip(final_gradients, var_list)) + + def get_gradients(self, loss, params): + """DP version of superclass method.""" + + # This code mostly follows the logic in the original DPOptimizerClass + # in dp_optimizer.py, except that this returns only the gradients, + # not the gradients and variables. + microbatch_losses = tf.reshape(loss, [self._num_microbatches, -1]) + sample_params = ( + self._dp_sum_query.derive_sample_params(self._global_state)) + + def process_microbatch(i, sample_state): + """Process one microbatch (record) with privacy helper.""" + mean_loss = tf.reduce_mean( + input_tensor=tf.gather(microbatch_losses, [i])) + grads = tf.gradients(mean_loss, params) + sample_state = self._dp_sum_query.accumulate_record( + sample_params, sample_state, grads) + return sample_state + + sample_state = self._dp_sum_query.initial_sample_state(params) + for idx in range(self._num_microbatches): + sample_state = process_microbatch(idx, sample_state) + grad_sums, self._global_state = ( + self._dp_sum_query.get_noised_result(sample_state, + self._global_state)) + + def normalize(v): + try: + return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32)) + except TypeError: + return None + + final_grads = tf.nest.map_structure(normalize, grad_sums) + + return final_grads + + return DPOptimizerClass + + +DPKerasAdagradOptimizer = make_keras_optimizer_class( + tf.keras.optimizers.Adagrad) +DPKerasAdamOptimizer = make_keras_optimizer_class(tf.keras.optimizers.Adam) +DPKerasSGDOptimizer = make_keras_optimizer_class(tf.keras.optimizers.SGD) diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py new file mode 100644 index 0000000..f98652f --- /dev/null +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_keras_test.py @@ -0,0 +1,257 @@ +# Copyright 2019, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for dp_optimizer_keras.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np +import tensorflow as tf + +from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras + + +class DPOptimizerComputeGradientsTest(tf.test.TestCase, parameterized.TestCase): + """Tests for _compute_gradients method.""" + + def _loss(self, val0, val1): + """Loss function whose derivative w.r.t val1 is val1 - val0.""" + return 0.5 * tf.reduce_sum( + input_tensor=tf.math.squared_difference(val0, val1), axis=1) + + # Parameters for testing: optimizer, num_microbatches, expected gradient for + # var0, expected gradient for var1. + @parameterized.named_parameters( + ('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1, + [-2.5, -2.5], [-0.5]), + ('DPAdam 2', dp_optimizer_keras.DPKerasAdamOptimizer, 2, + [-2.5, -2.5], [-0.5]), + ('DPAdagrad 4', dp_optimizer_keras.DPKerasAdagradOptimizer, 4, + [-2.5, -2.5], [-0.5]), + ) + def testBaseline(self, cls, num_microbatches, expected_grad0, expected_grad1): + var0 = tf.Variable([1.0, 2.0]) + var1 = tf.Variable([3.0]) + data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]]) + data1 = tf.Variable([[8.0], [2.0], [3.0], [1.0]]) + + opt = cls( + l2_norm_clip=100.0, + noise_multiplier=0.0, + num_microbatches=num_microbatches, + learning_rate=2.0) + + loss = lambda: self._loss(data0, var0) + self._loss(data1, var1) + + grads_and_vars = opt._compute_gradients(loss, [var0, var1]) + self.assertAllCloseAccordingToType(expected_grad0, grads_and_vars[0][0]) + self.assertAllCloseAccordingToType(expected_grad1, grads_and_vars[1][0]) + + @parameterized.named_parameters( + ('DPGradientDescent', dp_optimizer_keras.DPKerasSGDOptimizer),) + def testClippingNorm(self, cls): + var0 = tf.Variable([0.0, 0.0]) + data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]]) + + opt = cls( + l2_norm_clip=1.0, + noise_multiplier=0.0, + num_microbatches=1, + learning_rate=2.0) + + loss = lambda: self._loss(data0, var0) + # Expected gradient is sum of differences. + grads_and_vars = opt._compute_gradients(loss, [var0]) + self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0]) + + @parameterized.named_parameters( + ('DPGradientDescent 2 4 1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0, + 4.0, 1), + ('DPGradientDescent 4 1 4', dp_optimizer_keras.DPKerasSGDOptimizer, 4.0, + 1.0, 4), + ) + def testNoiseMultiplier(self, cls, l2_norm_clip, noise_multiplier, + num_microbatches): + var0 = tf.Variable(tf.zeros([1000], dtype=tf.float32)) + data0 = tf.Variable(tf.zeros([16, 1000], dtype=tf.float32)) + + opt = cls( + l2_norm_clip=l2_norm_clip, + noise_multiplier=noise_multiplier, + num_microbatches=num_microbatches, + learning_rate=2.0) + + loss = lambda: self._loss(data0, var0) + grads_and_vars = opt._compute_gradients(loss, [var0]) + grads = grads_and_vars[0][0].numpy() + + # Test standard deviation is close to l2_norm_clip * noise_multiplier. + self.assertNear( + np.std(grads), l2_norm_clip * noise_multiplier / num_microbatches, 0.5) + + +class DPOptimizerGetGradientsTest(tf.test.TestCase, parameterized.TestCase): + """Tests for get_gradient method. + + Since get_gradients must run in graph mode, the method is tested within + the Estimator framework. 
+ """ + + def _make_linear_model_fn(self, opt_cls, l2_norm_clip, noise_multiplier, + num_microbatches, learning_rate): + """Returns a model function for a linear regressor.""" + + def linear_model_fn(features, labels, mode): + layer = tf.keras.layers.Dense( + 1, + activation='linear', + name='dense', + kernel_initializer='zeros', + bias_initializer='zeros') + preds = layer.apply(features) + + vector_loss = 0.5 * tf.math.squared_difference(labels, preds) + scalar_loss = tf.reduce_mean(input_tensor=vector_loss) + + optimizer = opt_cls( + l2_norm_clip=l2_norm_clip, + noise_multiplier=noise_multiplier, + num_microbatches=num_microbatches, + learning_rate=learning_rate) + + params = layer.trainable_weights + global_step = tf.compat.v1.train.get_global_step() + train_op = tf.group( + optimizer.get_updates(loss=vector_loss, params=params), + [tf.compat.v1.assign_add(global_step, 1)]) + return tf.estimator.EstimatorSpec( + mode=mode, loss=scalar_loss, train_op=train_op) + + return linear_model_fn + + # Parameters for testing: optimizer, num_microbatches. + @parameterized.named_parameters( + ('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1), + ('DPGradientDescent 2', dp_optimizer_keras.DPKerasSGDOptimizer, 2), + ('DPGradientDescent 4', dp_optimizer_keras.DPKerasSGDOptimizer, 4),) + def testBaseline(self, cls, num_microbatches): + """Tests that DP optimizers work with tf.estimator.""" + + linear_regressor = tf.estimator.Estimator( + model_fn=self._make_linear_model_fn(cls, 100.0, 0.0, num_microbatches, + 0.05)) + + true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32) + true_bias = np.array([6.0]).astype(np.float32) + train_data = np.random.normal(scale=3.0, size=(1000, 4)).astype(np.float32) + + train_labels = np.matmul(train_data, + true_weights) + true_bias + np.random.normal( + scale=0.0, size=(1000, 1)).astype(np.float32) + + def train_input_fn(): + return tf.data.Dataset.from_tensor_slices( + (train_data, train_labels)).batch(8) + + linear_regressor.train(input_fn=train_input_fn, steps=125) + + self.assertAllClose( + linear_regressor.get_variable_value('dense/kernel'), + true_weights, + atol=0.05) + self.assertAllClose( + linear_regressor.get_variable_value('dense/bias'), true_bias, atol=0.05) + + # Parameters for testing: optimizer, num_microbatches. 
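+  # The clipping test below takes a single optimizer step with learning rate
+  # 1.0 from zero-initialized weights, so the learned weights are determined
+  # entirely by one (possibly clipped) gradient; with l2_norm_clip=1.0 they
+  # should equal the unclipped weights divided by their global norm.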
+ @parameterized.named_parameters( + ('DPGradientDescent 1', dp_optimizer_keras.DPKerasSGDOptimizer, 1),) + def testClippingNorm(self, cls, num_microbatches): + """Tests that DP optimizers work with tf.estimator.""" + + true_weights = np.array([[6.0], [0.0], [0], [0]]).astype(np.float32) + true_bias = np.array([0]).astype(np.float32) + + train_data = np.array([[1.0, 0.0, 0.0, 0.0]]).astype(np.float32) + train_labels = np.matmul(train_data, true_weights) + true_bias + + def train_input_fn(): + return tf.data.Dataset.from_tensor_slices( + (train_data, train_labels)).batch(1) + + unclipped_linear_regressor = tf.estimator.Estimator( + model_fn=self._make_linear_model_fn(cls, 1.0e9, 0.0, num_microbatches, + 1.0)) + unclipped_linear_regressor.train(input_fn=train_input_fn, steps=1) + + kernel_value = unclipped_linear_regressor.get_variable_value('dense/kernel') + bias_value = unclipped_linear_regressor.get_variable_value('dense/bias') + global_norm = np.linalg.norm(np.concatenate((kernel_value, [bias_value]))) + + clipped_linear_regressor = tf.estimator.Estimator( + model_fn=self._make_linear_model_fn(cls, 1.0, 0.0, num_microbatches, + 1.0)) + clipped_linear_regressor.train(input_fn=train_input_fn, steps=1) + + self.assertAllClose( + clipped_linear_regressor.get_variable_value('dense/kernel'), + kernel_value / global_norm, + atol=0.001) + self.assertAllClose( + clipped_linear_regressor.get_variable_value('dense/bias'), + bias_value / global_norm, + atol=0.001) + + # Parameters for testing: optimizer, num_microbatches. + @parameterized.named_parameters( + ('DPGradientDescent 2 4 1', dp_optimizer_keras.DPKerasSGDOptimizer, 2.0, + 4.0, 1), + ('DPGradientDescent 3 2 4', dp_optimizer_keras.DPKerasSGDOptimizer, 3.0, + 2.0, 4), + ('DPGradientDescent 8 6 8', dp_optimizer_keras.DPKerasSGDOptimizer, 8.0, + 6.0, 8), + ) + def testNoiseMultiplier(self, cls, l2_norm_clip, noise_multiplier, + num_microbatches): + """Tests that DP optimizers work with tf.estimator.""" + + linear_regressor = tf.estimator.Estimator( + model_fn=self._make_linear_model_fn( + cls, + l2_norm_clip, + noise_multiplier, + num_microbatches, + learning_rate=1.0)) + + true_weights = np.zeros((1000, 1), dtype=np.float32) + true_bias = np.array([0.0]).astype(np.float32) + + train_data = np.zeros((16, 1000), dtype=np.float32) + train_labels = np.matmul(train_data, true_weights) + true_bias + + def train_input_fn(): + return tf.data.Dataset.from_tensor_slices( + (train_data, train_labels)).batch(16) + + linear_regressor.train(input_fn=train_input_fn, steps=1) + + kernel_value = linear_regressor.get_variable_value('dense/kernel') + self.assertNear( + np.std(kernel_value), + l2_norm_clip * noise_multiplier / num_microbatches, 0.5) + + +if __name__ == '__main__': + tf.test.main()
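
For reference, a minimal usage sketch of the optimizers added above (not part of the diff; the model shape, hyperparameter values, and variable names are illustrative assumptions). The key requirement is an unreduced, per-example loss so that `_compute_gradients` can reshape it into `num_microbatches` groups:

# Illustrative sketch only: layer sizes and hyperparameters are assumptions.
import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer_keras

optimizer = dp_optimizer_keras.DPKerasSGDOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=32,  # must evenly divide the batch size
    learning_rate=0.1)

# Keep one loss value per example (no reduction) so the optimizer can split
# the vector loss into num_microbatches groups.
loss = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(20,)),
    tf.keras.layers.Dense(10),
])
model.compile(optimizer=optimizer, loss=loss)
# model.fit(x_train, y_train, batch_size=256, epochs=5)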