Working bolton model without unit tests.

Christopher Choquette Choo 2019-06-05 17:06:02 -04:00
parent d5dcfec745
commit 5f46927747
7 changed files with 884 additions and 0 deletions


@@ -0,0 +1,14 @@
import sys
from distutils.version import LooseVersion
import tensorflow as tf
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
raise ImportError("Please upgrade your version of tensorflow from: {0} "
"to at least 2.0.0 to use privacy/bolton".format(
LooseVersion(tf.__version__)))
if not hasattr(sys, 'skip_tf_privacy_import'):  # Useful for standalone scripts.
from privacy.bolton.model import Bolton
from privacy.bolton.loss import Huber
from privacy.bolton.loss import BinaryCrossentropy

280  privacy/bolton/loss.py  Normal file

@@ -0,0 +1,280 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loss functions for bolton method"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras import losses
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.framework import ops as _ops
class StrongConvexLoss(losses.Loss):
"""
Strong Convex Loss base class for any loss function that will be used with
Bolton model. Subclasses must be strongly convex and implement the
associated constants. They must also conform to the requirements of tf losses
(see super class)
"""
def __init__(self,
reg_lambda: float,
c: float,
radius_constant: float = 1,
reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
name: str = None,
dtype=tf.float32,
**kwargs):
"""
Args:
reg_lambda: Weight regularization constant
c: Additional constant for strongly convex convergence. Acts
as a global weight.
radius_constant: constant defining the length of the radius
reduction: reduction type to use. See super class
name: Name of the loss instance
dtype: tf datatype to use for tensor conversions.
"""
super(StrongConvexLoss, self).__init__(reduction=reduction,
name=name,
**kwargs)
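# The strongly convex constant c is kept as a non-trainable variable;
# subclass losses scale their per-sample values by it.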
self._sample_weight = tf.Variable(initial_value=c,
trainable=False,
dtype=tf.float32)
self._reg_lambda = reg_lambda
self.radius_constant = tf.Variable(initial_value=radius_constant,
trainable=False,
dtype=tf.float32)
self.dtype = dtype
def radius(self):
"""Radius of R-Ball (value to normalize weights to after each batch)
Returns: radius
"""
raise NotImplementedError("Radius not implemented for StrongConvex Loss"
"function: %s" % str(self.__class__.__name__))
def gamma(self):
"""Gamma, the strong convexity constant of the loss.
Returns: gamma
"""
raise NotImplementedError("Gamma not implemented for StrongConvex Loss "
"function: %s" % str(self.__class__.__name__))
def beta(self, class_weight):
"""Beta smoothness constant of the loss.
Args:
class_weight: the class weights used.
Returns: Beta
"""
raise NotImplementedError("Beta not implemented for StrongConvex Loss "
"function: %s" % str(self.__class__.__name__))
def lipchitz_constant(self, class_weight):
"""L, the Lipschitz constant of the loss.
Args:
class_weight: class weights used
Returns: L
"""
raise NotImplementedError("Lipschitz constant not implemented for "
"StrongConvex Loss "
"function: %s" % str(self.__class__.__name__))
def reg_lambda(self, convert_to_tensor: bool = False):
""" returns the lambda weight regularization constant, as a tensor if
desired
Args:
convert_to_tensor: True to convert to tensor, False to leave as
python numeric.
Returns: reg_lambda
"""
if convert_to_tensor:
return _ops.convert_to_tensor_v2(self._reg_lambda, dtype=self.dtype)
return self._reg_lambda
def max_class_weight(self, class_weight):
"""Returns the maximum value in class_weight as a tensor of the loss dtype."""
class_weight = _ops.convert_to_tensor_v2(class_weight, dtype=self.dtype)
return tf.math.reduce_max(class_weight)
class Huber(StrongConvexLoss, losses.Huber):
"""Strong Convex version of huber loss using l2 weight regularization.
"""
def __init__(self,
reg_lambda: float,
c: float,
radius_constant: float,
delta: float,
reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
name: str = 'huber',
dtype=tf.float32):
"""Constructor. Passes arguments to StrongConvexLoss and Huber Loss.
Args:
reg_lambda: Weight regularization constant
c: Additional constant for strongly convex convergence. Acts
as a global weight.
radius_constant: constant defining the length of the radius
delta: delta value in huber loss. When to switch from quadratic to
absolute deviation.
reduction: reduction type to use. See super class
name: Name of the loss instance
dtype: tf datatype to use for tensor conversions.
"""
super(Huber, self).__init__(
reg_lambda,
c,
radius_constant,
delta=delta,
name=name,
reduction=reduction,
dtype=dtype
)
self.delta = delta  # kept on the instance so beta() can use the huber delta.
def call(self, y_true, y_pred):
"""Compute loss
Args:
y_true: Ground truth values.
y_pred: The predicted values.
Returns:
Loss values per sample.
"""
return super(Huber, self).call(y_true, y_pred) * self._sample_weight
def radius(self):
"""See super class.
"""
return self.radius_constant / self.reg_lambda(True)
def gamma(self):
"""See super class.
"""
return self.reg_lambda(True)
def beta(self, class_weight):
"""See super class.
"""
max_class_weight = self.max_class_weight(class_weight)
return self._sample_weight * max_class_weight / (2.0 * self.delta) + \
self.reg_lambda(True)
def lipchitz_constant(self, class_weight):
"""See super class.
"""
# if class_weight is provided,
# it should be a vector of the same size of number of classes
max_class_weight = self.max_class_weight(class_weight)
lc = self._sample_weight * max_class_weight + \
self.reg_lambda(True) * self.radius()
return lc
class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy):
"""
Strongly convex version of the BinaryCrossentropy loss using l2 weight
regularization.
"""
def __init__(self,
reg_lambda: float,
c: float,
radius_constant: float,
from_logits: bool = True,
label_smoothing: float = 0,
reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
name: str = 'binarycrossentropy',
dtype=tf.float32):
"""
Args:
reg_lambda: Weight regularization constant
c: Additional constant for strongly convex convergence. Acts
as a global weight.
radius_constant: constant defining the length of the radius
from_logits: True if y_pred is given as logits rather than probabilities.
reduction: reduction type to use. See super class
label_smoothing: amount of smoothing to perform on labels
relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x)
name: Name of the loss instance
dtype: tf datatype to use for tensor conversions.
"""
super(BinaryCrossentropy, self).__init__(reg_lambda,
c,
radius_constant,
reduction=reduction,
name=name,
from_logits=from_logits,
label_smoothing=label_smoothing,
dtype=dtype
)
self.radius_constant = radius_constant
def call(self, y_true, y_pred):
"""Compute loss
Args:
y_true: Ground truth values.
y_pred: The predicted values.
Returns:
Loss values per sample.
"""
loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=y_true,
logits=y_pred
)
loss = loss * self._sample_weight
return loss
def radius(self):
"""See super class.
"""
return self.radius_constant / self.reg_lambda(True)
def gamma(self):
"""See super class.
"""
return self.reg_lambda(True)
def beta(self, class_weight):
"""See super class.
"""
max_class_weight = self.max_class_weight(class_weight)
return self._sample_weight * max_class_weight + self.reg_lambda(True)
def lipchitz_constant(self, class_weight):
"""See super class.
"""
max_class_weight = self.max_class_weight(class_weight)
return self._sample_weight * max_class_weight + \
self.reg_lambda(True) * self.radius()


@@ -0,0 +1,3 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

402  privacy/bolton/model.py  Normal file

@@ -0,0 +1,402 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bolton model for bolton method of differentially private ML"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras.models import Model
from tensorflow.python.keras import optimizers
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.framework import ops as _ops
from privacy.bolton.loss import StrongConvexLoss
from privacy.bolton.optimizer import Private
class Bolton(Model):
"""
Bolton epsilon-delta model
Uses 4 key steps to achieve privacy guarantees:
1. Adds noise to weights after training (output perturbation).
2. Projects weights to R after each batch
3. Limits learning rate
4. Use a strongly convex loss function (see compile)
"""
def __init__(self,
n_classes,
epsilon,
noise_distribution='laplace',
weights_initializer=tf.initializers.GlorotUniform(),
seed=1,
dtype=tf.float32
):
""" private constructor.
Args:
n_classes: number of output classes to predict.
epsilon: level of privacy guarantee
noise_distribution: distribution to pull weight perturbations from
weights_initializer: initializer for weights
seed: random seed to use
dtype: data type to use for tensors
"""
class MyCustomCallback(tf.keras.callbacks.Callback):
"""Custom callback for bolton training requirements.
Implements steps (see Bolton class):
2. Projects weights to R after each batch
3. Limits learning rate
"""
def on_train_batch_end(self, batch, logs=None):
loss = self.model.loss
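# Step 3 (see the Bolton class docstring): cap the learning rate using the
# loss's beta-smoothness and gamma-strong convexity constants.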
self.model.optimizer.limit_learning_rate(
self.model.run_eagerly,
loss.beta(self.model.class_weight),
loss.gamma()
)
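# Step 2: project the weights back onto the R-ball after the batch update.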
self.model._project_weights_to_r(loss.radius(), False)
def on_train_end(self, logs=None):
loss = self.model.loss
self.model._project_weights_to_r(loss.radius(), True)
super(Bolton, self).__init__(name='bolton', dynamic=False)
self.n_classes = n_classes
self.output_layer = tf.keras.layers.Dense(
self.n_classes,
kernel_regularizer=tf.keras.regularizers.l2(),
kernel_initializer=weights_initializer,
)
# if we do regularization here, we require the user to re-instantiate
# the model each time they want to
# change lambda, unless we standardize modifying it later at .compile
self.force = False
self.noise_distribution = noise_distribution
self.epsilon = epsilon
self.seed = seed
self.__in_fit = False
self._callback = MyCustomCallback()
self._dtype = dtype
def call(self, inputs):
"""Forward pass of network
Args:
inputs: inputs to neural network
Returns: Output of the model's single dense output layer.
"""
return self.output_layer(inputs)
def compile(self,
optimizer='SGD',
loss=None,
metrics=None,
loss_weights=None,
sample_weight_mode=None,
weighted_metrics=None,
target_tensors=None,
distribute=None,
**kwargs):
"""See super class. Default optimizer used in Bolton method is SGD.
"""
if not isinstance(loss, StrongConvexLoss):
raise ValueError("Loss must be subclassed from StrongConvexLoss")
self.output_layer.kernel_regularizer.l2 = loss.reg_lambda()
if not isinstance(optimizer, Private):
optimizer = optimizers.get(optimizer)
if isinstance(self.optimizer, trackable.Trackable):
self._track_trackable(
self.optimizer, name='optimizer', overwrite=True
)
optimizer = Private(optimizer)
super(Bolton, self).compile(optimizer,
loss=loss,
metrics=metrics,
loss_weights=loss_weights,
sample_weight_mode=sample_weight_mode,
weighted_metrics=weighted_metrics,
target_tensors=target_tensors,
distribute=distribute,
**kwargs
)
def _post_fit(self, x, n_samples):
"""Implements 1-time weight changes needed for Bolton method.
In this case, specifically implements the noise addition
assuming a strongly convex function.
Args:
x: inputs
n_samples: number of samples in the inputs. In case the number
cannot be readily determined by inspecting x.
Returns:
"""
if n_samples is not None:
data_size = n_samples
elif hasattr(x, 'shape'):
data_size = x.shape[0]
elif hasattr(x, "__len__"):
data_size = len(x)
else:
raise ValueError("Unable to detect the number of training "
"samples and n_samples was None. "
"Either pass a dataset with a .shape or "
"__len__ attribute or explicitly pass the "
"number of samples as n_samples.")
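# Step 1: output perturbation -- add noise drawn from the configured
# distribution to each layer's kernel.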
for layer in self._layers:
layer.kernel = layer.kernel + self._get_noise(
self.noise_distribution,
data_size
)
def fit(self,
x=None,
y=None,
batch_size=None,
epochs=1,
verbose=1,
callbacks=None,
validation_split=0.0,
validation_data=None,
shuffle=True,
class_weight=None,
sample_weight=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
validation_freq=1,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
n_samples=None,
**kwargs):
"""Reroutes to super fit with additional Bolton delta-epsilon privacy
requirements implemented. Note, inputs must be normalized s.t. ||x|| < 1
Requirements are as follows:
1. Adds noise to weights after training (output perturbation).
2. Projects weights to R after each batch
3. Limits learning rate
4. Use a strongly convex loss function (see compile)
See super implementation for more details.
Args:
n_samples: the number of individual samples in x.
Returns: The output of keras Model.fit (a History object).
"""
self.__in_fit = True
cb = [self._callback]
if callbacks is not None:
cb.extend(callbacks)
callbacks = cb
if class_weight is None:
class_weight = self.calculate_class_weights(class_weight)
self.class_weight = class_weight
out = super(Bolton, self).fit(x=x,
y=y,
batch_size=batch_size,
epochs=epochs,
verbose=verbose,
callbacks=callbacks,
validation_split=validation_split,
validation_data=validation_data,
shuffle=shuffle,
class_weight=class_weight,
sample_weight=sample_weight,
initial_epoch=initial_epoch,
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps,
validation_freq=validation_freq,
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
**kwargs
)
self._post_fit(x, n_samples)
self.__in_fit = False
return out
def fit_generator(self,
generator,
steps_per_epoch=None,
epochs=1,
verbose=1,
callbacks=None,
validation_data=None,
validation_steps=None,
validation_freq=1,
class_weight=None,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
shuffle=True,
initial_epoch=0,
n_samples=None
):
"""
This method is the same as fit except for when the passed dataset
is a generator. See super method and fit for more details.
Args:
n_samples: number of individual samples in x
"""
if class_weight is None:
class_weight = self.calculate_class_weights(class_weight)
self.class_weight = class_weight
out = super(Bolton, self).fit_generator(
generator,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
verbose=verbose,
callbacks=callbacks,
validation_data=validation_data,
validation_steps=validation_steps,
validation_freq=validation_freq,
class_weight=class_weight,
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
shuffle=shuffle,
initial_epoch=initial_epoch
)
if not self.__in_fit:
self._post_fit(generator, n_samples)
return out
def calculate_class_weights(self,
class_weights=None,
class_counts=None,
num_classes=None
):
"""
Calculates class weighting to be used in training. class_weights may be
None, the string 'balanced', or an explicit array of per-class weights.
Args:
class_weights: str specifying type, array giving weights, or None.
class_counts: If class_weights is not None, then the number of
samples for each class
num_classes: If class_weights is not None, then the number of
classes.
Returns: class_weights as 1D tensor, to be passed to model's fit method.
"""
# Value checking
class_keys = ['balanced']
is_string = False
if isinstance(class_weights, str):
is_string = True
if class_weights not in class_keys:
raise ValueError("Detected string class_weights with "
"value: {0}, which is not one of {1}. "
"Please select a valid class_weight type "
"or pass an array".format(class_weights,
class_keys))
if class_counts is None:
raise ValueError("class_counts must be provided if using "
"class_weights=%s" % class_weights)
if num_classes is None:
raise ValueError("num_classes must be provided if using "
"class_weights=%s" % class_weights)
elif class_weights is not None:
if num_classes is None:
raise ValueError("You must pass a value for num_classes if "
"creating an array of class_weights")
# performing class weight calculation
if class_weights is None:
class_weights = 1
elif is_string and class_weights == 'balanced':
num_samples = sum(class_counts)
class_weights = tf.Variable(
num_samples / (num_classes * class_counts),
dtype=self._dtype
)
else:
class_weights = _ops.convert_to_tensor_v2(class_weights)
if len(class_weights.shape) != 1:
raise ValueError("Detected class_weights shape: {0} instead of "
"1D array".format(class_weights.shape))
if class_weights.shape[0] != num_classes:
raise ValueError(
"Detected array length: {0} instead of: {1}".format(
class_weights.shape[0],
num_classes
)
)
return class_weights
def _project_weights_to_r(self, r, force=False):
"""helper method to normalize the weights to the R-ball.
Args:
r: radius of "R-Ball". Scalar to normalize to.
force: True to normalize regardless of previous weight values.
False to check if weights > R-ball and only normalize then.
Returns:
"""
for layer in self._layers:
weight_norm = tf.norm(layer.kernel, axis=0)
if force:
layer.kernel = layer.kernel / (weight_norm / r)
elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self._dtype)) > 0:
layer.kernel = layer.kernel / (weight_norm / r)
def _get_noise(self, distribution, data_size):
"""Sample noise to be added to weights for privacy guarantee
Args:
distribution: the distribution type to pull noise from
data_size: the number of samples
Returns: noise in shape of layer's weights to be added to the weights.
"""
distribution = distribution.lower()
input_dim = self._layers[0].kernel.numpy().shape[0]
loss = self.loss
if distribution == 'laplace':
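# Split the total privacy budget epsilon evenly across the output classes.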
per_class_epsilon = self.epsilon / (self.n_classes)
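# l2 sensitivity used to calibrate the noise: 2 * L / (gamma * data_size),
# where L is the loss's Lipschitz constant for the given class weights.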
l2_sensitivity = (2 *
loss.lipchitz_constant(self.class_weight)) / \
(loss.gamma() * data_size)
unit_vector = tf.random.normal(shape=(input_dim, self.n_classes),
mean=0,
seed=self.seed,
stddev=1.0,
dtype=self._dtype)
unit_vector = unit_vector / tf.math.sqrt(
tf.reduce_sum(tf.math.square(unit_vector), axis=0)
)
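# A standard normal vector rescaled to unit length is uniformly distributed
# on the unit sphere; this gives the direction of the noise for each class.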
beta = l2_sensitivity / per_class_epsilon
alpha = input_dim  # shape parameter of the gamma distribution below
gamma = tf.random.gamma([self.n_classes],
alpha,
beta=1 / beta,
seed=self.seed,
dtype=self._dtype)
return unit_vector * gamma
raise NotImplementedError("distribution: {0} is not "
"currently supported".format(distribution))


@@ -0,0 +1,3 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

173  privacy/bolton/optimizer.py  Normal file

@@ -0,0 +1,173 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Private Optimizer for bolton method"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
_private_attributes = ['_internal_optimizer', 'dtype']
class Private(optimizer_v2.OptimizerV2):
"""
Private optimizer wraps another tf optimizer to be used
as the visible optimizer to the tf model. No matter the optimizer
passed, "Private" enables the bolton model to control the learning rate
based on the strongly convex loss.
"""
def __init__(self,
optimizer: optimizer_v2.OptimizerV2,
dtype=tf.float32
):
"""Constructor.
Args:
optimizer: Optimizer_v2 or subclass to be used as the optimizer
(wrapped).
dtype: tf datatype to use for internal tensor conversions.
"""
self._internal_optimizer = optimizer
self.dtype = dtype
def get_config(self):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
return self._internal_optimizer.get_config()
def limit_learning_rate(self, is_eager, beta, gamma):
"""Implements learning rate limitation that is required by the bolton
method for sensitivity bounding of the strongly convex function.
Sets the learning rate to the min(1/beta, 1/(gamma*t))
Args:
is_eager: Whether the model is running in eager mode
beta: loss function beta-smoothness
gamma: loss function gamma-strongly convex
Returns: None
"""
numerator = tf.Variable(initial_value=1, dtype=self.dtype)
t = tf.cast(self._iterations, self.dtype)
# will exist on the internal optimizer
pred = numerator / beta < numerator / (gamma * t)
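# Pick whichever bound is smaller: the constant 1/beta cap or the
# decaying 1/(gamma * t) rate.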
if is_eager: # check eagerly
if pred:
self.learning_rate = numerator / beta
else:
self.learning_rate = numerator / (gamma * t)
else:
if pred:
self.learning_rate = numerator / beta
else:
self.learning_rate = numerator / (gamma * t)
def from_config(self, config, custom_objects=None):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
return self._internal_optimizer.from_config(
config,
custom_objects=custom_objects
)
def __getattr__(self, name):
"""return _internal_optimizer off self instance, and everything else
from the _internal_optimizer instance.
Args:
name:
Returns: attribute from Private if specified to come from self, else
from _internal_optimizer.
"""
if name in _private_attributes:
return getattr(self, name)
optim = object.__getattribute__(self, '_internal_optimizer')
return object.__getattribute__(optim, name)
def __setattr__(self, key, value):
""" Set attribute to self instance if its the internal optimizer.
Reroute everything else to the _internal_optimizer.
Args:
key: attribute name
value: attribute value
Returns:
"""
if key in _private_attributes:
object.__setattr__(self, key, value)
else:
setattr(self._internal_optimizer, key, value)
def _resource_apply_dense(self, grad, handle):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
return self._internal_optimizer._resource_apply_dense(grad, handle)
def _resource_apply_sparse(self, grad, handle, indices):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
return self._internal_optimizer._resource_apply_sparse(
grad,
handle,
indices
)
def get_updates(self, loss, params):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
return self._internal_optimizer.get_updates(loss, params)
def apply_gradients(self, grads_and_vars, name: str = None):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
return self._internal_optimizer.apply_gradients(
grads_and_vars,
name=name
)
def minimize(self,
loss,
var_list,
grad_loss: bool = None,
name: str = None
):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
return self._internal_optimizer.minimize(
loss,
var_list,
grad_loss,
name
)
def _compute_gradients(self, loss, var_list, grad_loss=None):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
return self._internal_optimizer._compute_gradients(
loss,
var_list,
grad_loss=grad_loss
)
def get_gradients(self, loss, params):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
return self._internal_optimizer.get_gradients(loss, params)


@@ -0,0 +1,9 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.platform import test
from tensorflow.python.keras import keras_parameterized
from privacy.bolton import model