# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bolton model for the bolton method of differentially private ML."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.python.keras.models import Model
from tensorflow.python.keras import optimizers
from tensorflow.python.framework import ops as _ops
from privacy.bolton.loss import StrongConvexMixin
from privacy.bolton.optimizer import Private

_accepted_distributions = ['laplace']


class Bolton(Model):
  """Bolton epsilon-delta model.

  Uses four key steps to achieve privacy guarantees:
    1. Adds noise to the weights after training (output perturbation).
    2. Projects the weights to R after each batch.
    3. Limits the learning rate.
    4. Uses a strongly convex loss function (see compile).

  For more details on the strong convexity requirements, see:
  Bolt-on Differential Privacy for Scalable Stochastic Gradient
  Descent-based Analytics by Xi Wu et al.
  """

  def __init__(self,
               n_classes,
               epsilon,
               noise_distribution='laplace',
               seed=1,
               dtype=tf.float32
               ):
    """Private constructor.

    Args:
        n_classes: number of output classes to predict.
        epsilon: level of the privacy guarantee.
        noise_distribution: distribution to pull weight perturbations from.
        seed: random seed to use.
        dtype: data type to use for tensors.
    """

    class MyCustomCallback(tf.keras.callbacks.Callback):
      """Custom callback for bolton training requirements.

      Implements the per-batch steps (see the Bolton class docstring):
        2. Projects the weights to R after each batch.
        3. Limits the learning rate.
      """

      def on_train_batch_end(self, batch, logs=None):
        loss = self.model.loss
        self.model.optimizer.limit_learning_rate(
            self.model.run_eagerly,
            loss.beta(self.model.class_weight),
            loss.gamma()
        )
        self.model._project_weights_to_r(loss.radius(), False)

      def on_train_end(self, logs=None):
        loss = self.model.loss
        self.model._project_weights_to_r(loss.radius(), True)

    if epsilon <= 0:
      raise ValueError('Detected epsilon: {0}. '
                       'Valid range is 0 < epsilon < inf'.format(epsilon))

  def _project_weights_to_r(self, r, force=False):
    """Normalizes the layer weights to the R-ball of radius r.

    Args:
        r: radius of the R-ball; the scalar norm to project to.
        force: True to normalize regardless of the current weight values;
            False to check whether the weights lie outside the
            R-ball and only normalize then.

    Returns:
        None. The layer kernels are updated in place.
    """
    for layer in self._layers:
      weight_norm = tf.norm(layer.kernel, axis=0)
      if force:
        layer.kernel = layer.kernel / (weight_norm / r)
      elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self._dtype)) > 0:
        layer.kernel = layer.kernel / (weight_norm / r)
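
  # Note on the mechanism implemented by _get_noise below: a d-dimensional
  # Laplace-style noise sample can be factored into a direction drawn
  # uniformly from the unit sphere (a normalized Gaussian sample) and a
  # magnitude drawn from Gamma(shape=d, scale=l2_sensitivity / epsilon).
  # The per-class epsilon split and the 2 * L / (gamma * n) sensitivity
  # follow the Bolt-on paper referenced in the class docstring.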
""" distribution = distribution.lower() input_dim = self._layers[0].kernel.numpy().shape[0] loss = self.loss if distribution == _accepted_distributions[0]: # laplace per_class_epsilon = self.epsilon / (self.n_classes) l2_sensitivity = (2 * loss.lipchitz_constant(self.class_weight)) / \ (loss.gamma() * data_size) unit_vector = tf.random.normal(shape=(input_dim, self.n_classes), mean=0, seed=1, stddev=1.0, dtype=self._dtype) unit_vector = unit_vector / tf.math.sqrt( tf.reduce_sum(tf.math.square(unit_vector), axis=0) ) beta = l2_sensitivity / per_class_epsilon alpha = input_dim # input_dim gamma = tf.random.gamma([self.n_classes], alpha, beta=1 / beta, seed=1, dtype=self._dtype ) return unit_vector * gamma raise NotImplementedError('Noise distribution: {0} is not ' 'a valid distribution'.format(distribution))