# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BoltOn model for the Bolt-on method of differentially private ML."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.python.framework import ops as _ops
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.models import Model
from privacy.bolt_on.losses import StrongConvexMixin
from privacy.bolt_on.optimizers import BoltOn


class BoltOnModel(Model):  # pylint: disable=abstract-method
  """BoltOn epsilon-delta differential privacy model.

  The privacy guarantees depend on the noise that is sampled. Please
  see the paper linked below for more details.

  Uses four key steps to achieve privacy guarantees:
    1. Adds noise to the weights after training (output perturbation).
    2. Projects the weights onto the ball of radius R after each batch.
    3. Limits the learning rate.
    4. Uses a strongly convex loss function (see compile).

  For more details on the strong convexity requirements, see:
  Bolt-on Differential Privacy for Scalable Stochastic Gradient
  Descent-based Analytics by Xi Wu et al.
  """

  def __init__(self,
               n_outputs,
               seed=1,
               dtype=tf.float32):
    """Private constructor.

    Args:
      n_outputs: number of output classes to predict.
      seed: random seed to use.
      dtype: data type to use for tensors.
    """
    super(BoltOnModel, self).__init__(name='bolton', dynamic=False)
    if n_outputs <= 0:
      raise ValueError('n_outputs = {0} is not valid. Must be > 0.'.format(
          n_outputs))
    self.n_outputs = n_outputs
    self.seed = seed
    self._layers_instantiated = False
    self._dtype = dtype

  def call(self, inputs):  # pylint: disable=arguments-differ
    """Forward pass of the network.

    Args:
      inputs: inputs to the neural network.

    Returns:
      Output logits for the given inputs.
    """
    return self.output_layer(inputs)

  def compile(self,
              optimizer,
              loss,
              kernel_initializer=tf.initializers.GlorotUniform,
              **kwargs):  # pylint: disable=arguments-differ
    """See super class. The default optimizer used in the BoltOn method is SGD.

    Args:
      optimizer: The optimizer to use. This will be automatically wrapped
        with the BoltOn optimizer.
      loss: The loss function to use. Must be strongly convex, i.e. must
        extend the StrongConvexMixin.
      kernel_initializer: The kernel initializer to use for the single layer.
      **kwargs: kwargs to keras Model.compile. See super.
    """
    if not isinstance(loss, StrongConvexMixin):
      raise ValueError('loss function must be strongly convex and therefore '
                       'extend the StrongConvexMixin.')
    if not self._layers_instantiated:  # compile may be called multiple times,
      # for instance, if the inputs/outputs are not defined until fit.
      self.output_layer = tf.keras.layers.Dense(
          self.n_outputs,
          kernel_regularizer=loss.kernel_regularizer(),
          kernel_initializer=kernel_initializer(),
      )
      self._layers_instantiated = True
    if not isinstance(optimizer, BoltOn):
      optimizer = optimizers.get(optimizer)
      optimizer = BoltOn(optimizer, loss)

    super(BoltOnModel, self).compile(optimizer, loss=loss, **kwargs)
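
  # Hedged example of a valid compile() call: the loss must extend
  # StrongConvexMixin, and the positional arguments shown for
  # StrongConvexBinaryCrossentropy (reg_lambda, C, radius_constant) are an
  # assumption based on the BoltOn tutorials, not guaranteed by this class.
  #
  #   from privacy.bolt_on.losses import StrongConvexBinaryCrossentropy
  #   model.compile(optimizer='sgd',
  #                 loss=StrongConvexBinaryCrossentropy(1.0, 1.0, 1.0))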

  def fit(self,
          x=None,
          y=None,
          batch_size=None,
          class_weight=None,
          n_samples=None,
          epsilon=2,
          noise_distribution='laplace',
          steps_per_epoch=None,
          **kwargs):  # pylint: disable=arguments-differ
    """Reroutes to super fit with BoltOn epsilon-delta privacy requirements.

    Note: inputs must be normalized s.t. ||x|| < 1.

    Requirements are as follows:
      1. Adds noise to the weights after training (output perturbation).
      2. Projects the weights onto the ball of radius R after each batch.
      3. Limits the learning rate.
      4. Uses a strongly convex loss function (see compile).
    See the super implementation for more details.

    Args:
      x: Inputs to fit on, see super.
      y: Labels to fit on, see super.
      batch_size: The batch size to use for training, see super.
      class_weight: the class weights to be used. Can be a scalar or a 1D
        tensor whose dim == n_classes.
      n_samples: the number of individual samples in x.
      epsilon: privacy parameter, which trades off between utility and
        privacy. See the Bolt-on paper for more details.
      noise_distribution: the distribution to pull noise from.
      steps_per_epoch: Number of steps per training epoch, see super.
      **kwargs: kwargs to keras Model.fit. See super.

    Returns:
      Output from super fit method.
    """
    if class_weight is None:
      class_weight_ = self.calculate_class_weights(class_weight)
    else:
      class_weight_ = class_weight
    if n_samples is not None:
      data_size = n_samples
    elif hasattr(x, 'shape'):
      data_size = x.shape[0]
    elif hasattr(x, '__len__'):
      data_size = len(x)
    else:
      data_size = None
    # Infer the batch_size to be passed to the optimizer; batch_size itself
    # must keep its initial value when passed to super().fit().
    batch_size_ = self._validate_or_infer_batch_size(batch_size,
                                                     steps_per_epoch,
                                                     x)
    if batch_size_ is None:
      batch_size_ = 32
    if data_size is None:
      raise ValueError('Could not infer the number of samples. Please pass '
                       'this in using n_samples.')
    with self.optimizer(noise_distribution,
                        epsilon,
                        self.layers,
                        class_weight_,
                        data_size,
                        batch_size_) as _:
      out = super(BoltOnModel, self).fit(x=x,
                                         y=y,
                                         batch_size=batch_size,
                                         class_weight=class_weight,
                                         steps_per_epoch=steps_per_epoch,
                                         **kwargs)
    return out
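
  # Hedged example of a fit() call: x is assumed to be pre-normalized so that
  # ||x|| < 1, and n_samples is passed explicitly because the length of a
  # dataset iterator often cannot be inferred.
  #
  #   model.fit(x, y, batch_size=32, n_samples=10000,
  #             epsilon=2, noise_distribution='laplace')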

  def fit_generator(self,
                    generator,
                    class_weight=None,
                    noise_distribution='laplace',
                    epsilon=2,
                    n_samples=None,
                    steps_per_epoch=None,
                    **kwargs):  # pylint: disable=arguments-differ
    """Fit with a generator.

    This method is the same as fit except that the passed dataset
    is a generator. See super method and fit for more details.

    Args:
      generator: Inputs generator following TensorFlow guidelines, see super.
      class_weight: the class weights to be used. Can be a scalar or a 1D
        tensor whose dim == n_classes.
      noise_distribution: the distribution to get noise from.
      epsilon: privacy parameter, which trades off between utility and
        privacy. See the Bolt-on paper for more details.
      n_samples: number of individual samples in the generator.
      steps_per_epoch: Number of steps per training epoch, see super.
      **kwargs: kwargs to keras Model.fit_generator. See super.

    Returns:
      Output from super fit_generator method.
    """
    if class_weight is None:
      class_weight = self.calculate_class_weights(class_weight)
    if n_samples is not None:
      data_size = n_samples
    elif hasattr(generator, 'shape'):
      data_size = generator.shape[0]
    elif hasattr(generator, '__len__'):
      data_size = len(generator)
    else:
      raise ValueError('The number of samples could not be determined. '
                       'If you are using a generator, please call this '
                       'method with the n_samples kwarg passed.')
    batch_size = self._validate_or_infer_batch_size(None, steps_per_epoch,
                                                    generator)
    if batch_size is None:
      batch_size = 32
    with self.optimizer(noise_distribution,
                        epsilon,
                        self.layers,
                        class_weight,
                        data_size,
                        batch_size) as _:
      out = super(BoltOnModel, self).fit_generator(
          generator,
          class_weight=class_weight,
          steps_per_epoch=steps_per_epoch,
          **kwargs)
    return out
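
  # Hedged example of a fit_generator() call: `train_generator` is a
  # hypothetical name, and n_samples is required whenever the generator has
  # no `shape` or `__len__` from which the sample count can be inferred.
  #
  #   model.fit_generator(train_generator, n_samples=10000, epsilon=2,
  #                       steps_per_epoch=100)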

  def calculate_class_weights(self,
                              class_weights=None,
                              class_counts=None,
                              num_classes=None):
    """Calculates the class weighting to be used in training.

    Args:
      class_weights: str specifying the weighting type ('balanced'), an array
        giving the weights, or None.
      class_counts: If class_weights is 'balanced', an array of the number of
        samples for each class.
      num_classes: If class_weights is 'balanced' or an array, the number of
        classes.

    Returns:
      class_weights as a 1D tensor, to be passed to the model's fit method.
    """
    # Value checking
    class_keys = ['balanced']
    is_string = False
    if isinstance(class_weights, str):
      is_string = True
      if class_weights not in class_keys:
        raise ValueError('Detected string class_weights with '
                         'value: {0}, which is not one of {1}. '
                         'Please select a valid class_weight type '
                         'or pass an array.'.format(class_weights,
                                                    class_keys))
      if class_counts is None:
        raise ValueError('Class counts must be provided if using '
                         'class_weights=%s' % class_weights)
      class_counts_shape = tf.Variable(class_counts,
                                       trainable=False,
                                       dtype=self._dtype).shape
      if len(class_counts_shape) != 1:
        raise ValueError('class_counts must be a 1D array. '
                         'Detected: {0}'.format(class_counts_shape))
      if num_classes is None:
        raise ValueError('num_classes must be provided if using '
                         'class_weights=%s' % class_weights)
    elif class_weights is not None:
      if num_classes is None:
        raise ValueError('You must pass a value for num_classes if '
                         'creating an array of class_weights.')
    # Performing the class weight calculation: for 'balanced' weighting,
    # class i gets weight num_samples / (num_classes * class_counts[i]).
    if class_weights is None:
      class_weights = 1
    elif is_string and class_weights == 'balanced':
      num_samples = sum(class_counts)
      weighted_counts = tf.dtypes.cast(tf.math.multiply(num_classes,
                                                        class_counts),
                                       self._dtype)
      class_weights = tf.Variable(num_samples, dtype=self._dtype) / \
          tf.Variable(weighted_counts, dtype=self._dtype)
    else:
      class_weights = _ops.convert_to_tensor_v2(class_weights)
      if len(class_weights.shape) != 1:
        raise ValueError('Detected class_weights shape: {0} instead of '
                         '1D array.'.format(class_weights.shape))
      if class_weights.shape[0] != num_classes:
        raise ValueError(
            'Detected array length: {0} instead of: {1}.'.format(
                class_weights.shape[0],
                num_classes))
    return class_weights
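

def _demo():  # pragma: no cover
  """A minimal, hedged usage sketch; not part of the library API.

  Assumes StrongConvexBinaryCrossentropy is available from
  privacy.bolt_on.losses (as in the BoltOn tutorials) with positional
  arguments (reg_lambda, C, radius_constant).
  """
  from privacy.bolt_on.losses import StrongConvexBinaryCrossentropy  # pylint: disable=g-import-not-at-top

  # Toy binary classification data, scaled so that ||x|| < 1 as fit requires.
  x = tf.random.uniform((32, 2), minval=-0.5, maxval=0.5)
  y = tf.cast(tf.random.uniform((32, 1)) > 0.5, tf.float32)
  model = BoltOnModel(n_outputs=1)
  model.compile(optimizer='sgd',
                loss=StrongConvexBinaryCrossentropy(1.0, 1.0, 1.0))
  # 'balanced' weighting: with counts [24, 8], the weights come out to
  # 32 / (2 * [24, 8]) = [0.667, 2.0].
  print(model.calculate_class_weights('balanced',
                                      class_counts=[24, 8],
                                      num_classes=2))
  model.fit(x, y, batch_size=8, n_samples=32, epsilon=2,
            noise_distribution='laplace')


if __name__ == '__main__':
  _demo()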