tensorflow_privacy/privacy/bolton/model.py
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bolton model for bolton method of differentially private ML"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras.models import Model
from tensorflow.python.keras import optimizers
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.framework import ops as _ops
from privacy.bolton.loss import StrongConvexLoss
from privacy.bolton.optimizer import Private
class Bolton(Model):
"""
Bolton episilon-delta model
Uses 4 key steps to achieve privacy guarantees:
1. Adds noise to weights after training (output perturbation).
2. Projects weights to R after each batch
3. Limits learning rate
4. Use a strongly convex loss function (see compile)
"""
def __init__(self,
n_classes,
epsilon,
noise_distribution='laplace',
weights_initializer=tf.initializers.GlorotUniform(),
seed=1,
dtype=tf.float32
):
""" private constructor.
Args:
n_classes: number of output classes to predict.
epsilon: level of privacy guarantee
noise_distribution: distribution to pull weight perturbations from
weights_initializer: initializer for weights
seed: random seed to use
dtype: data type to use for tensors
"""
class MyCustomCallback(tf.keras.callbacks.Callback):
"""Custom callback for bolton training requirements.
Implements steps (see Bolton class):
2. Projects weights to R after each batch
3. Limits learning rate
"""
def on_train_batch_end(self, batch, logs=None):
loss = self.model.loss
self.model.optimizer.limit_learning_rate(
self.model.run_eagerly,
loss.beta(self.model.class_weight),
loss.gamma()
)
self.model._project_weights_to_r(loss.radius(), False)
def on_train_end(self, logs=None):
loss = self.model.loss
self.model._project_weights_to_r(loss.radius(), True)
super(Bolton, self).__init__(name='bolton', dynamic=False)
self.n_classes = n_classes
self.output_layer = tf.keras.layers.Dense(
self.n_classes,
kernel_regularizer=tf.keras.regularizers.l2(),
kernel_initializer=weights_initializer,
)
# If the regularization strength were fixed here, the user would have to
# re-instantiate the model every time they wanted to change lambda, so it
# is instead set from the loss's reg_lambda() at .compile time.
self.force = False
self.noise_distribution = noise_distribution
self.epsilon = epsilon
self.seed = seed
self.__in_fit = False
self._callback = MyCustomCallback()
self._dtype = dtype
def call(self, inputs):
"""Forward pass of network
Args:
inputs: inputs to neural network
Returns:
"""
return self.output_layer(inputs)
def compile(self,
optimizer='SGD',
loss=None,
metrics=None,
loss_weights=None,
sample_weight_mode=None,
weighted_metrics=None,
target_tensors=None,
distribute=None,
**kwargs):
"""See super class. Default optimizer used in Bolton method is SGD.
"""
if not isinstance(loss, StrongConvexLoss):
raise ValueError("Loss must be subclassed from StrongConvexLoss")
self.output_layer.kernel_regularizer.l2 = loss.reg_lambda()
if not isinstance(optimizer, Private):
# Fetch the keras optimizer, track it, and wrap it in the Private optimizer
# so its learning rate can be limited as the Bolton method requires.
optimizer = optimizers.get(optimizer)
if isinstance(optimizer, trackable.Trackable):
self._track_trackable(
optimizer, name='optimizer', overwrite=True
)
optimizer = Private(optimizer)
super(Bolton, self).compile(optimizer,
loss=loss,
metrics=metrics,
loss_weights=loss_weights,
sample_weight_mode=sample_weight_mode,
weighted_metrics=weighted_metrics,
target_tensors=target_tensors,
distribute=distribute,
**kwargs
)
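# Example (sketch): compiling with a strongly convex loss. The loss name is
# hypothetical; compile raises ValueError for any loss that is not a
# StrongConvexLoss subclass, copies the loss's reg_lambda() onto the output
# layer's kernel regularizer, and wraps the optimizer in Private:
#
#   model.compile(optimizer=tf.optimizers.SGD(), loss=SomeStrongConvexLoss())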
def _post_fit(self, x, n_samples):
"""Implements 1-time weight changes needed for Bolton method.
In this case, specifically implements the noise addition
assuming a strongly convex function.
Args:
x: inputs
n_samples: number of samples in the inputs. In case the number
cannot be readily determined by inspecting x.
Returns:
"""
if n_samples is not None:
data_size = n_samples
elif hasattr(x, 'shape'):
data_size = x.shape[0]
elif hasattr(x, "__len__"):
data_size = len(x)
else:
raise ValueError("Unable to detect the number of training "
"samples and n_samples was None. "
"Either pass a dataset with a .shape or "
"__len__ attribute or explicitly pass the "
"number of samples as n_samples.")
for layer in self._layers:
layer.kernel = layer.kernel + self._get_noise(
self.noise_distribution,
data_size
)
def fit(self,
x=None,
y=None,
batch_size=None,
epochs=1,
verbose=1,
callbacks=None,
validation_split=0.0,
validation_data=None,
shuffle=True,
class_weight=None,
sample_weight=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
validation_freq=1,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
n_samples=None,
**kwargs):
"""Reroutes to super fit with additional Bolton delta-epsilon privacy
requirements implemented. Note, inputs must be normalized s.t. ||x|| < 1
Requirements are as follows:
1. Adds noise to weights after training (output perturbation).
2. Projects weights to R after each batch
3. Limits learning rate
4. Use a strongly convex loss function (see compile)
See super implementation for more details.
Args:
n_samples: the number of individual samples in x.
Returns:
"""
self.__in_fit = True
cb = [self._callback]
if callbacks is not None:
cb.extend(callbacks)
callbacks = cb
if class_weight is None:
class_weight = self.calculate_class_weights(class_weight)
self.class_weight = class_weight
out = super(Bolton, self).fit(x=x,
y=y,
batch_size=batch_size,
epochs=epochs,
verbose=verbose,
callbacks=callbacks,
validation_split=validation_split,
validation_data=validation_data,
shuffle=shuffle,
class_weight=class_weight,
sample_weight=sample_weight,
initial_epoch=initial_epoch,
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps,
validation_freq=validation_freq,
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
**kwargs
)
self._post_fit(x, n_samples)
self.__in_fit = False
return out
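# The privacy analysis assumes every input has norm at most 1. A minimal
# normalization sketch (assuming x is a dense numpy array and numpy is
# imported as np; names are illustrative only):
#
#   max_norm = np.linalg.norm(x, axis=1).max()
#   x = x / (max_norm + 1e-12)  # now ||x_i|| < 1 for every row
#   model.fit(x, y, n_samples=x.shape[0])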
def fit_generator(self,
generator,
steps_per_epoch=None,
epochs=1,
verbose=1,
callbacks=None,
validation_data=None,
validation_steps=None,
validation_freq=1,
class_weight=None,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
shuffle=True,
initial_epoch=0,
n_samples=None
):
"""
This method is the same as fit except for when the passed dataset
is a generator. See super method and fit for more details.
Args:
n_samples: number of individual samples in x
"""
if class_weight is None:
class_weight = self.calculate_class_weights(class_weight)
self.class_weight = class_weight
out = super(Bolton, self).fit_generator(
generator,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
verbose=verbose,
callbacks=callbacks,
validation_data=validation_data,
validation_steps=validation_steps,
validation_freq=validation_freq,
class_weight=class_weight,
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
shuffle=shuffle,
initial_epoch=initial_epoch
)
if not self.__in_fit:
self._post_fit(generator, n_samples)
return out
def calculate_class_weights(self,
class_weights=None,
class_counts=None,
num_classes=None
):
"""
Calculates class weighting to be used in training. Can be on
Args:
class_weights: str specifying type, array giving weights, or None.
class_counts: If class_weights is not None, then the number of
samples for each class
num_classes: If class_weights is not None, then the number of
classes.
Returns: class_weights as 1D tensor, to be passed to model's fit method.
"""
# Value checking
class_keys = ['balanced']
is_string = False
if isinstance(class_weights, str):
is_string = True
if class_weights not in class_keys:
raise ValueError("Detected string class_weights with "
"value: {0}, which is not one of {1}. "
"Please select a valid class_weight type "
"or pass an array".format(class_weights,
class_keys))
if class_counts is None:
raise ValueError("class_counts must be provided if using "
"class_weights=%s" % class_weights)
if num_classes is None:
raise ValueError("num_classes must be provided if using "
"class_weights=%s" % class_weights)
elif class_weights is not None:
if num_classes is None:
raise ValueError("You must pass a value for num_classes if "
"creating an array of class_weights")
# performing class weight calculation
if class_weights is None:
class_weights = 1
elif is_string and class_weights == 'balanced':
num_samples = sum(class_counts)
class_weights = tf.Variable(
num_samples / (num_classes * class_counts),
dtype=self._dtype
)
else:
class_weights = _ops.convert_to_tensor_v2(class_weights)
if len(class_weights.shape) != 1:
raise ValueError("Detected class_weights shape: {0} instead of "
"1D array".format(class_weights.shape))
if class_weights.shape[0] != num_classes:
raise ValueError(
"Detected array length: {0} instead of: {1}".format(
class_weights.shape[0],
num_classes
)
)
return class_weights
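# Worked example (sketch, illustrative values; assumes numpy imported as np):
# with class_counts = np.array([10, 30]) and num_classes = 2, the 'balanced'
# branch computes
#   num_samples   = 40
#   class_weights = 40 / (2 * [10, 30]) = [2.0, 0.667]
# so the under-represented class is up-weighted:
#
#   weights = model.calculate_class_weights('balanced',
#                                           class_counts=np.array([10, 30]),
#                                           num_classes=2)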
def _project_weights_to_r(self, r, force=False):
"""helper method to normalize the weights to the R-ball.
Args:
r: radius of "R-Ball". Scalar to normalize to.
force: True to normalize regardless of previous weight values.
False to check if weights > R-ball and only normalize then.
Returns:
"""
for layer in self._layers:
weight_norm = tf.norm(layer.kernel, axis=0)
if force:
layer.kernel = layer.kernel / (weight_norm / r)
elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self._dtype)) > 0:
layer.kernel = layer.kernel / (weight_norm / r)
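# Note (informal): the division above rescales each class's weight column w_c
# to w_c * (r / ||w_c||), i.e. back onto the R-ball of radius r, as the Bolton
# analysis requires after every batch.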
def _get_noise(self, distribution, data_size):
"""Sample noise to be added to weights for privacy guarantee
Args:
distribution: the distribution type to pull noise from
data_size: the number of samples
Returns: noise in shape of layer's weights to be added to the weights.
"""
distribution = distribution.lower()
input_dim = self._layers[0].kernel.numpy().shape[0]
loss = self.loss
if distribution == 'laplace':
per_class_epsilon = self.epsilon / (self.n_classes)
l2_sensitivity = (2 *
loss.lipchitz_constant(self.class_weight)) / \
(loss.gamma() * data_size)
unit_vector = tf.random.normal(shape=(input_dim, self.n_classes),
mean=0,
seed=self.seed,
stddev=1.0,
dtype=self._dtype)
unit_vector = unit_vector / tf.math.sqrt(
tf.reduce_sum(tf.math.square(unit_vector), axis=0)
)
beta = l2_sensitivity / per_class_epsilon
alpha = input_dim  # shape parameter of the Gamma-distributed noise magnitude
gamma = tf.random.gamma([self.n_classes],
alpha,
beta=1 / beta,
seed=self.seed,
dtype=self._dtype)
return unit_vector * gamma
raise NotImplementedError("distribution: {0} is not "
"currently supported".format(distribution))