# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bolton model for the Bolton method of differentially private ML."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.python.keras.models import Model
from tensorflow.python.keras import optimizers
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.framework import ops as _ops
from privacy.bolton.loss import StrongConvexLoss
from privacy.bolton.optimizer import Private


class Bolton(Model):
  """Bolton epsilon-delta model.

  Uses 4 key steps to achieve privacy guarantees:
    1. Adds noise to the weights after training (output perturbation).
    2. Projects the weights to the R-ball after each batch.
    3. Limits the learning rate.
    4. Uses a strongly convex loss function (see compile).
  """

  def __init__(self,
               n_classes,
               epsilon,
               noise_distribution='laplace',
               weights_initializer=tf.initializers.GlorotUniform(),
               seed=1,
               dtype=tf.float32
               ):
    """Constructor for the Bolton model.

    Args:
      n_classes: number of output classes to predict.
      epsilon: level of the privacy guarantee.
      noise_distribution: distribution to pull weight perturbations from.
      weights_initializer: initializer for the weights.
      seed: random seed to use.
      dtype: data type to use for tensors.
    """

    class MyCustomCallback(tf.keras.callbacks.Callback):
      """Custom callback for Bolton training requirements.

      Implements steps (see the Bolton class):
        2. Projects the weights to the R-ball after each batch.
        3. Limits the learning rate.
      """

      def on_train_batch_end(self, batch, logs=None):
        loss = self.model.loss
        self.model.optimizer.limit_learning_rate(
            self.model.run_eagerly,
            loss.beta(self.model.class_weight),
            loss.gamma()
        )
        self.model._project_weights_to_r(loss.radius(), False)

      def on_train_end(self, logs=None):
        loss = self.model.loss
        self.model._project_weights_to_r(loss.radius(), True)

    super(Bolton, self).__init__(name='bolton', dynamic=False)
    self.n_classes = n_classes
    self.output_layer = tf.keras.layers.Dense(
        self.n_classes,
        kernel_regularizer=tf.keras.regularizers.l2(),
        kernel_initializer=weights_initializer,
    )
    # If we do regularization here, we require the user to re-instantiate the
    # model each time they want to change lambda, unless we standardize
    # modifying it later at .compile.
    self.force = False
    self.noise_distribution = noise_distribution
    self.epsilon = epsilon
    self.seed = seed
    self.__in_fit = False
    self._callback = MyCustomCallback()
    self._dtype = dtype

  def call(self, inputs):
    """Forward pass of the network.

    Args:
      inputs: inputs to the neural network.

    Returns:
      Output of the model's single dense output layer.
    """
    return self.output_layer(inputs)

  def compile(self,
              optimizer='SGD',
              loss=None,
              metrics=None,
              loss_weights=None,
              sample_weight_mode=None,
              weighted_metrics=None,
              target_tensors=None,
              distribute=None,
              **kwargs):
    """See super class. The default optimizer for the Bolton method is SGD.
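
    `loss` must be an instance of StrongConvexLoss; a ValueError is raised
    otherwise. Any optimizer that is not already a Private optimizer is
    resolved with keras `optimizers.get` and wrapped in a Private optimizer
    before being passed to the super class.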
""" if not isinstance(loss, StrongConvexLoss): raise ValueError("Loss must be subclassed from StrongConvexLoss") self.output_layer.kernel_regularizer.l2 = loss.reg_lambda() if not isinstance(optimizer, Private): optimizer = optimizers.get(optimizer) if isinstance(self.optimizer, trackable.Trackable): self._track_trackable( self.optimizer, name='optimizer', overwrite=True ) optimizer = Private(optimizer) super(Bolton, self).compile(optimizer, loss=loss, metrics=metrics, loss_weights=loss_weights, sample_weight_mode=sample_weight_mode, weighted_metrics=weighted_metrics, target_tensors=target_tensors, distribute=distribute, **kwargs ) def _post_fit(self, x, n_samples): """Implements 1-time weight changes needed for Bolton method. In this case, specifically implements the noise addition assuming a strongly convex function. Args: x: inputs n_samples: number of samples in the inputs. In case the number cannot be readily determined by inspecting x. Returns: """ if n_samples is not None: data_size = n_samples elif hasattr(x, 'shape'): data_size = x.shape[0] elif hasattr(x, "__len__"): data_size = len(x) else: if n_samples is None: raise ValueError("Unable to detect the number of training " "samples and n_smaples was None. " "either pass a dataset with a .shape or " "__len__ attribute or explicitly pass the " "number of samples as n_smaples.") data_size = n_samples for layer in self._layers: layer.kernel = layer.kernel + self._get_noise( self.noise_distribution, data_size ) def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None, validation_freq=1, max_queue_size=10, workers=1, use_multiprocessing=False, n_samples=None, **kwargs): """Reroutes to super fit with additional Bolton delta-epsilon privacy requirements implemented. Note, inputs must be normalized s.t. ||x|| < 1 Requirements are as follows: 1. Adds noise to weights after training (output perturbation). 2. Projects weights to R after each batch 3. Limits learning rate 4. Use a strongly convex loss function (see compile) See super implementation for more details. Args: n_samples: the number of individual samples in x. Returns: """ self.__in_fit = True cb = [self._callback] if callbacks is not None: cb.extend(callbacks) callbacks = cb if class_weight is None: class_weight = self.calculate_class_weights(class_weight) self.class_weight = class_weight out = super(Bolton, self).fit(x=x, y=y, batch_size=batch_size, epochs=epochs, verbose=verbose, callbacks=callbacks, validation_split=validation_split, validation_data=validation_data, shuffle=shuffle, class_weight=class_weight, sample_weight=sample_weight, initial_epoch=initial_epoch, steps_per_epoch=steps_per_epoch, validation_steps=validation_steps, validation_freq=validation_freq, max_queue_size=max_queue_size, workers=workers, use_multiprocessing=use_multiprocessing, **kwargs ) self._post_fit(x, n_samples) self.__in_fit = False return out def fit_generator(self, generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, validation_freq=1, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0, n_samples=None ): """ This method is the same as fit except for when the passed dataset is a generator. See super method and fit for more details. 

    Args:
      n_samples: number of individual samples in x.
    """
    if class_weight is None:
      class_weight = self.calculate_class_weights(class_weight)
    self.class_weight = class_weight
    out = super(Bolton, self).fit_generator(
        generator,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        verbose=verbose,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=validation_steps,
        validation_freq=validation_freq,
        class_weight=class_weight,
        max_queue_size=max_queue_size,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
        shuffle=shuffle,
        initial_epoch=initial_epoch
    )
    if not self.__in_fit:
      self._post_fit(generator, n_samples)
    return out

  def calculate_class_weights(self,
                              class_weights=None,
                              class_counts=None,
                              num_classes=None
                              ):
    """Calculates the class weighting to be used in training.

    class_weights can be one of 'balanced', an explicit 1D array of weights,
    or None.

    Args:
      class_weights: str specifying the weighting type, an array giving the
        weights, or None.
      class_counts: if class_weights is a string, the number of samples for
        each class.
      num_classes: if class_weights is not None, the number of classes.

    Returns:
      class_weights as a 1D tensor, to be passed to the model's fit method.
    """
    # Value checking
    class_keys = ['balanced']
    is_string = False
    if isinstance(class_weights, str):
      is_string = True
      if class_weights not in class_keys:
        raise ValueError("Detected string class_weights with value: {0}, "
                         "which is not one of {1}. Please select a valid "
                         "class_weight type or pass an array".format(
                             class_weights, class_keys))
      if class_counts is None:
        raise ValueError("class_counts must be provided if using "
                         "class_weights=%s" % class_weights)
      if num_classes is None:
        raise ValueError("num_classes must be provided if using "
                         "class_weights=%s" % class_weights)
    elif class_weights is not None:
      if num_classes is None:
        raise ValueError("You must pass a value for num_classes if "
                         "creating an array of class_weights")
    # Performing the class weight calculation
    if class_weights is None:
      class_weights = 1
    elif is_string and class_weights == 'balanced':
      num_samples = sum(class_counts)
      class_weights = tf.Variable(
          num_samples / (num_classes * class_counts),
          dtype=self._dtype
      )
    else:
      class_weights = _ops.convert_to_tensor_v2(class_weights)
      if len(class_weights.shape) != 1:
        raise ValueError("Detected class_weights shape: {0} instead of a "
                         "1D array".format(class_weights.shape))
      if class_weights.shape[0] != num_classes:
        raise ValueError(
            "Detected array length: {0} instead of: {1}".format(
                class_weights.shape[0],
                num_classes
            )
        )
    return class_weights

  def _project_weights_to_r(self, r, force=False):
    """Helper method to normalize the weights to the R-ball.

    Args:
      r: radius of the "R-ball". Scalar to normalize to.
      force: True to normalize regardless of the previous weight values.
        False to only normalize when the weight norms exceed the R-ball.
    """
    for layer in self._layers:
      weight_norm = tf.norm(layer.kernel, axis=0)
      if force:
        layer.kernel = layer.kernel / (weight_norm / r)
      elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self._dtype)) > 0:
        layer.kernel = layer.kernel / (weight_norm / r)

  def _get_noise(self, distribution, data_size):
    """Samples noise to be added to the weights for the privacy guarantee.

    Args:
      distribution: the distribution type to pull noise from.
      data_size: the number of samples.

    Returns:
      Noise in the shape of the layer's weights, to be added to the weights.
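
    For the 'laplace' distribution, the noise direction is drawn uniformly on
    the unit sphere and its magnitude from a Gamma(input_dim,
    l2_sensitivity / per_class_epsilon) distribution, which together give
    noise with density proportional to
    exp(-per_class_epsilon * ||noise|| / l2_sensitivity).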
""" distribution = distribution.lower() input_dim = self._layers[0].kernel.numpy().shape[0] loss = self.loss if distribution == 'laplace': per_class_epsilon = self.epsilon / (self.n_classes) l2_sensitivity = (2 * loss.lipchitz_constant(self.class_weight)) / \ (loss.gamma() * data_size) unit_vector = tf.random.normal(shape=(input_dim, self.n_classes), mean=0, seed=1, stddev=1.0, dtype=self._dtype) unit_vector = unit_vector / tf.math.sqrt( tf.reduce_sum(tf.math.square(unit_vector), axis=0) ) beta = l2_sensitivity / per_class_epsilon alpha = input_dim # input_dim gamma = tf.random.gamma([self.n_classes], alpha, beta=1 / beta, seed=1, dtype=self._dtype) return unit_vector * gamma raise NotImplementedError("distribution: {0} is not " "currently supported".format(distribution))