Working bolton model without unit tests.
This commit is contained in:
parent
d5dcfec745
commit
5f46927747
7 changed files with 884 additions and 0 deletions
14
privacy/bolton/__init__.py
Normal file
14
privacy/bolton/__init__.py
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
import sys
|
||||||
|
from distutils.version import LooseVersion
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
|
||||||
|
raise ImportError("Please upgrade your version of tensorflow from: {0} "
|
||||||
|
"to at least 2.0.0 to use privacy/bolton".format(
|
||||||
|
LooseVersion(tf.__version__)))
|
||||||
|
if hasattr(sys, 'skip_tf_privacy_import'): # Useful for standalone scripts.
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
from privacy.bolton.model import Bolton
|
||||||
|
from privacy.bolton.loss import Huber
|
||||||
|
from privacy.bolton.loss import BinaryCrossentropy
|
280
privacy/bolton/loss.py
Normal file
280
privacy/bolton/loss.py
Normal file
|
@ -0,0 +1,280 @@
|
||||||
|
# Copyright 2018, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Loss functions for bolton method"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.keras import losses
|
||||||
|
from tensorflow.python.keras.utils import losses_utils
|
||||||
|
from tensorflow.python.framework import ops as _ops
|
||||||
|
|
||||||
|
|
||||||
|
class StrongConvexLoss(losses.Loss):
|
||||||
|
"""
|
||||||
|
Strong Convex Loss base class for any loss function that will be used with
|
||||||
|
Bolton model. Subclasses must be strongly convex and implement the
|
||||||
|
associated constants. They must also conform to the requirements of tf losses
|
||||||
|
(see super class)
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
reg_lambda: float,
|
||||||
|
c: float,
|
||||||
|
radius_constant: float = 1,
|
||||||
|
reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
|
||||||
|
name: str = None,
|
||||||
|
dtype=tf.float32,
|
||||||
|
**kwargs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
reg_lambda: Weight regularization constant
|
||||||
|
c: Additional constant for strongly convex convergence. Acts
|
||||||
|
as a global weight.
|
||||||
|
radius_constant: constant defining the length of the radius
|
||||||
|
reduction: reduction type to use. See super class
|
||||||
|
name: Name of the loss instance
|
||||||
|
dtype: tf datatype to use for tensor conversions.
|
||||||
|
"""
|
||||||
|
super(StrongConvexLoss, self).__init__(reduction=reduction,
|
||||||
|
name=name,
|
||||||
|
**kwargs)
|
||||||
|
self._sample_weight = tf.Variable(initial_value=c,
|
||||||
|
trainable=False,
|
||||||
|
dtype=tf.float32)
|
||||||
|
self._reg_lambda = reg_lambda
|
||||||
|
self.radius_constant = tf.Variable(initial_value=radius_constant,
|
||||||
|
trainable=False,
|
||||||
|
dtype=tf.float32)
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
def radius(self):
|
||||||
|
"""Radius of R-Ball (value to normalize weights to after each batch)
|
||||||
|
|
||||||
|
Returns: radius
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Radius not implemented for StrongConvex Loss"
|
||||||
|
"function: %s" % str(self.__class__.__name__))
|
||||||
|
|
||||||
|
def gamma(self):
|
||||||
|
""" Gamma strongly convex
|
||||||
|
|
||||||
|
Returns: gamma
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Gamma not implemented for StrongConvex Loss"
|
||||||
|
"function: %s" % str(self.__class__.__name__))
|
||||||
|
|
||||||
|
def beta(self, class_weight):
|
||||||
|
"""Beta smoothess
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_weight: the class weights used.
|
||||||
|
|
||||||
|
Returns: Beta
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Beta not implemented for StrongConvex Loss"
|
||||||
|
"function: %s" % str(self.__class__.__name__))
|
||||||
|
|
||||||
|
def lipchitz_constant(self, class_weight):
|
||||||
|
""" L lipchitz continuous
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_weight: class weights used
|
||||||
|
|
||||||
|
Returns: L
|
||||||
|
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("lipchitz constant not implemented for "
|
||||||
|
"StrongConvex Loss"
|
||||||
|
"function: %s" % str(self.__class__.__name__))
|
||||||
|
|
||||||
|
def reg_lambda(self, convert_to_tensor: bool = False):
|
||||||
|
""" returns the lambda weight regularization constant, as a tensor if
|
||||||
|
desired
|
||||||
|
|
||||||
|
Args:
|
||||||
|
convert_to_tensor: True to convert to tensor, False to leave as
|
||||||
|
python numeric.
|
||||||
|
|
||||||
|
Returns: reg_lambda
|
||||||
|
|
||||||
|
"""
|
||||||
|
if convert_to_tensor:
|
||||||
|
return _ops.convert_to_tensor_v2(self._reg_lambda, dtype=self.dtype)
|
||||||
|
return self._reg_lambda
|
||||||
|
|
||||||
|
def max_class_weight(self, class_weight):
|
||||||
|
class_weight = _ops.convert_to_tensor_v2(class_weight, dtype=self.dtype)
|
||||||
|
return tf.math.reduce_max(class_weight)
|
||||||
|
|
||||||
|
|
||||||
|
class Huber(StrongConvexLoss, losses.Huber):
|
||||||
|
"""Strong Convex version of huber loss using l2 weight regularization.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
reg_lambda: float,
|
||||||
|
c: float,
|
||||||
|
radius_constant: float,
|
||||||
|
delta: float,
|
||||||
|
reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
|
||||||
|
name: str = 'huber',
|
||||||
|
dtype=tf.float32):
|
||||||
|
"""Constructor. Passes arguments to StrongConvexLoss and Huber Loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reg_lambda: Weight regularization constant
|
||||||
|
c: Additional constant for strongly convex convergence. Acts
|
||||||
|
as a global weight.
|
||||||
|
radius_constant: constant defining the length of the radius
|
||||||
|
delta: delta value in huber loss. When to switch from quadratic to
|
||||||
|
absolute deviation.
|
||||||
|
reduction: reduction type to use. See super class
|
||||||
|
name: Name of the loss instance
|
||||||
|
dtype: tf datatype to use for tensor conversions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loss values per sample.
|
||||||
|
"""
|
||||||
|
# self.delta = tf.Variable(initial_value=delta, trainable=False)
|
||||||
|
super(Huber, self).__init__(
|
||||||
|
reg_lambda,
|
||||||
|
c,
|
||||||
|
radius_constant,
|
||||||
|
delta=delta,
|
||||||
|
name=name,
|
||||||
|
reduction=reduction,
|
||||||
|
dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
def call(self, y_true, y_pred):
|
||||||
|
"""Compute loss
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y_true: Ground truth values.
|
||||||
|
y_pred: The predicted values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loss values per sample.
|
||||||
|
"""
|
||||||
|
return super(Huber, self).call(y_true, y_pred, **self._fn_kwargs) * \
|
||||||
|
self._sample_weight
|
||||||
|
|
||||||
|
def radius(self):
|
||||||
|
"""See super class.
|
||||||
|
"""
|
||||||
|
return self.radius_constant / self.reg_lambda(True)
|
||||||
|
|
||||||
|
def gamma(self):
|
||||||
|
"""See super class.
|
||||||
|
"""
|
||||||
|
return self.reg_lambda(True)
|
||||||
|
|
||||||
|
def beta(self, class_weight):
|
||||||
|
"""See super class.
|
||||||
|
"""
|
||||||
|
max_class_weight = self.max_class_weight(class_weight)
|
||||||
|
return self._sample_weight * max_class_weight / \
|
||||||
|
(self.delta * tf.Variable(initial_value=2, trainable=False)) + \
|
||||||
|
self.reg_lambda(True)
|
||||||
|
|
||||||
|
def lipchitz_constant(self, class_weight):
|
||||||
|
"""See super class.
|
||||||
|
"""
|
||||||
|
# if class_weight is provided,
|
||||||
|
# it should be a vector of the same size of number of classes
|
||||||
|
max_class_weight = self.max_class_weight(class_weight)
|
||||||
|
lc = self._sample_weight * max_class_weight + \
|
||||||
|
self.reg_lambda(True) * self.radius()
|
||||||
|
return lc
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy):
|
||||||
|
"""
|
||||||
|
Strong Convex version of BinaryCrossentropy loss using l2 weight
|
||||||
|
regularization.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
reg_lambda: float,
|
||||||
|
c: float,
|
||||||
|
radius_constant: float,
|
||||||
|
from_logits: bool = True,
|
||||||
|
label_smoothing: float = 0,
|
||||||
|
reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
|
||||||
|
name: str = 'binarycrossentropy',
|
||||||
|
dtype=tf.float32):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
reg_lambda: Weight regularization constant
|
||||||
|
c: Additional constant for strongly convex convergence. Acts
|
||||||
|
as a global weight.
|
||||||
|
radius_constant: constant defining the length of the radius
|
||||||
|
reduction: reduction type to use. See super class
|
||||||
|
label_smoothing: amount of smoothing to perform on labels
|
||||||
|
relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x)
|
||||||
|
name: Name of the loss instance
|
||||||
|
dtype: tf datatype to use for tensor conversions.
|
||||||
|
"""
|
||||||
|
super(BinaryCrossentropy, self).__init__(reg_lambda,
|
||||||
|
c,
|
||||||
|
radius_constant,
|
||||||
|
reduction=reduction,
|
||||||
|
name=name,
|
||||||
|
from_logits=from_logits,
|
||||||
|
label_smoothing=label_smoothing,
|
||||||
|
dtype=dtype
|
||||||
|
)
|
||||||
|
self.radius_constant = radius_constant
|
||||||
|
|
||||||
|
def call(self, y_true, y_pred):
|
||||||
|
"""Compute loss
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y_true: Ground truth values.
|
||||||
|
y_pred: The predicted values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loss values per sample.
|
||||||
|
"""
|
||||||
|
loss = tf.nn.sigmoid_cross_entropy_with_logits(
|
||||||
|
labels=y_true,
|
||||||
|
logits=y_pred
|
||||||
|
)
|
||||||
|
loss = loss * self._sample_weight
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def radius(self):
|
||||||
|
"""See super class.
|
||||||
|
"""
|
||||||
|
return self.radius_constant / self.reg_lambda(True)
|
||||||
|
|
||||||
|
def gamma(self):
|
||||||
|
"""See super class.
|
||||||
|
"""
|
||||||
|
return self.reg_lambda(True)
|
||||||
|
|
||||||
|
def beta(self, class_weight):
|
||||||
|
"""See super class.
|
||||||
|
"""
|
||||||
|
max_class_weight = self.max_class_weight(class_weight)
|
||||||
|
return self._sample_weight * max_class_weight + self.reg_lambda(True)
|
||||||
|
|
||||||
|
def lipchitz_constant(self, class_weight):
|
||||||
|
"""See super class.
|
||||||
|
"""
|
||||||
|
max_class_weight = self.max_class_weight(class_weight)
|
||||||
|
return self._sample_weight * max_class_weight + \
|
||||||
|
self.reg_lambda(True) * self.radius()
|
3
privacy/bolton/loss_test.py
Normal file
3
privacy/bolton/loss_test.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
402
privacy/bolton/model.py
Normal file
402
privacy/bolton/model.py
Normal file
|
@ -0,0 +1,402 @@
|
||||||
|
# Copyright 2018, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Bolton model for bolton method of differentially private ML"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.keras.models import Model
|
||||||
|
from tensorflow.python.keras import optimizers
|
||||||
|
from tensorflow.python.training.tracking import base as trackable
|
||||||
|
from tensorflow.python.framework import ops as _ops
|
||||||
|
from privacy.bolton.loss import StrongConvexLoss
|
||||||
|
from privacy.bolton.optimizer import Private
|
||||||
|
|
||||||
|
|
||||||
|
class Bolton(Model):
|
||||||
|
"""
|
||||||
|
Bolton episilon-delta model
|
||||||
|
Uses 4 key steps to achieve privacy guarantees:
|
||||||
|
1. Adds noise to weights after training (output perturbation).
|
||||||
|
2. Projects weights to R after each batch
|
||||||
|
3. Limits learning rate
|
||||||
|
4. Use a strongly convex loss function (see compile)
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
n_classes,
|
||||||
|
epsilon,
|
||||||
|
noise_distribution='laplace',
|
||||||
|
weights_initializer=tf.initializers.GlorotUniform(),
|
||||||
|
seed=1,
|
||||||
|
dtype=tf.float32
|
||||||
|
):
|
||||||
|
""" private constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_classes: number of output classes to predict.
|
||||||
|
epsilon: level of privacy guarantee
|
||||||
|
noise_distribution: distribution to pull weight perturbations from
|
||||||
|
weights_initializer: initializer for weights
|
||||||
|
seed: random seed to use
|
||||||
|
dtype: data type to use for tensors
|
||||||
|
"""
|
||||||
|
|
||||||
|
class MyCustomCallback(tf.keras.callbacks.Callback):
|
||||||
|
"""Custom callback for bolton training requirements.
|
||||||
|
Implements steps (see Bolton class):
|
||||||
|
2. Projects weights to R after each batch
|
||||||
|
3. Limits learning rate
|
||||||
|
"""
|
||||||
|
def on_train_batch_end(self, batch, logs=None):
|
||||||
|
loss = self.model.loss
|
||||||
|
self.model.optimizer.limit_learning_rate(
|
||||||
|
self.model.run_eagerly,
|
||||||
|
loss.beta(self.model.class_weight),
|
||||||
|
loss.gamma()
|
||||||
|
)
|
||||||
|
self.model._project_weights_to_r(loss.radius(), False)
|
||||||
|
|
||||||
|
def on_train_end(self, logs=None):
|
||||||
|
loss = self.model.loss
|
||||||
|
self.model._project_weights_to_r(loss.radius(), True)
|
||||||
|
|
||||||
|
super(Bolton, self).__init__(name='bolton', dynamic=False)
|
||||||
|
self.n_classes = n_classes
|
||||||
|
self.output_layer = tf.keras.layers.Dense(
|
||||||
|
self.n_classes,
|
||||||
|
kernel_regularizer=tf.keras.regularizers.l2(),
|
||||||
|
kernel_initializer=weights_initializer,
|
||||||
|
)
|
||||||
|
# if we do regularization here, we require the user to re-instantiate
|
||||||
|
# the model each time they want to
|
||||||
|
# change lambda, unless we standardize modifying it later at .compile
|
||||||
|
self.force = False
|
||||||
|
self.noise_distribution = noise_distribution
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.seed = seed
|
||||||
|
self.__in_fit = False
|
||||||
|
self._callback = MyCustomCallback()
|
||||||
|
self._dtype = dtype
|
||||||
|
|
||||||
|
def call(self, inputs):
|
||||||
|
"""Forward pass of network
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: inputs to neural network
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
return self.output_layer(inputs)
|
||||||
|
|
||||||
|
def compile(self,
|
||||||
|
optimizer='SGD',
|
||||||
|
loss=None,
|
||||||
|
metrics=None,
|
||||||
|
loss_weights=None,
|
||||||
|
sample_weight_mode=None,
|
||||||
|
weighted_metrics=None,
|
||||||
|
target_tensors=None,
|
||||||
|
distribute=None,
|
||||||
|
**kwargs):
|
||||||
|
"""See super class. Default optimizer used in Bolton method is SGD.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not isinstance(loss, StrongConvexLoss):
|
||||||
|
raise ValueError("Loss must be subclassed from StrongConvexLoss")
|
||||||
|
self.output_layer.kernel_regularizer.l2 = loss.reg_lambda()
|
||||||
|
if not isinstance(optimizer, Private):
|
||||||
|
optimizer = optimizers.get(optimizer)
|
||||||
|
if isinstance(self.optimizer, trackable.Trackable):
|
||||||
|
self._track_trackable(
|
||||||
|
self.optimizer, name='optimizer', overwrite=True
|
||||||
|
)
|
||||||
|
optimizer = Private(optimizer)
|
||||||
|
|
||||||
|
super(Bolton, self).compile(optimizer,
|
||||||
|
loss=loss,
|
||||||
|
metrics=metrics,
|
||||||
|
loss_weights=loss_weights,
|
||||||
|
sample_weight_mode=sample_weight_mode,
|
||||||
|
weighted_metrics=weighted_metrics,
|
||||||
|
target_tensors=target_tensors,
|
||||||
|
distribute=distribute,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def _post_fit(self, x, n_samples):
|
||||||
|
"""Implements 1-time weight changes needed for Bolton method.
|
||||||
|
In this case, specifically implements the noise addition
|
||||||
|
assuming a strongly convex function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: inputs
|
||||||
|
n_samples: number of samples in the inputs. In case the number
|
||||||
|
cannot be readily determined by inspecting x.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
if n_samples is not None:
|
||||||
|
data_size = n_samples
|
||||||
|
elif hasattr(x, 'shape'):
|
||||||
|
data_size = x.shape[0]
|
||||||
|
elif hasattr(x, "__len__"):
|
||||||
|
data_size = len(x)
|
||||||
|
else:
|
||||||
|
if n_samples is None:
|
||||||
|
raise ValueError("Unable to detect the number of training "
|
||||||
|
"samples and n_smaples was None. "
|
||||||
|
"either pass a dataset with a .shape or "
|
||||||
|
"__len__ attribute or explicitly pass the "
|
||||||
|
"number of samples as n_smaples.")
|
||||||
|
data_size = n_samples
|
||||||
|
|
||||||
|
for layer in self._layers:
|
||||||
|
layer.kernel = layer.kernel + self._get_noise(
|
||||||
|
self.noise_distribution,
|
||||||
|
data_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def fit(self,
|
||||||
|
x=None,
|
||||||
|
y=None,
|
||||||
|
batch_size=None,
|
||||||
|
epochs=1,
|
||||||
|
verbose=1,
|
||||||
|
callbacks=None,
|
||||||
|
validation_split=0.0,
|
||||||
|
validation_data=None,
|
||||||
|
shuffle=True,
|
||||||
|
class_weight=None,
|
||||||
|
sample_weight=None,
|
||||||
|
initial_epoch=0,
|
||||||
|
steps_per_epoch=None,
|
||||||
|
validation_steps=None,
|
||||||
|
validation_freq=1,
|
||||||
|
max_queue_size=10,
|
||||||
|
workers=1,
|
||||||
|
use_multiprocessing=False,
|
||||||
|
n_samples=None,
|
||||||
|
**kwargs):
|
||||||
|
"""Reroutes to super fit with additional Bolton delta-epsilon privacy
|
||||||
|
requirements implemented. Note, inputs must be normalized s.t. ||x|| < 1
|
||||||
|
Requirements are as follows:
|
||||||
|
1. Adds noise to weights after training (output perturbation).
|
||||||
|
2. Projects weights to R after each batch
|
||||||
|
3. Limits learning rate
|
||||||
|
4. Use a strongly convex loss function (see compile)
|
||||||
|
See super implementation for more details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_samples: the number of individual samples in x.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.__in_fit = True
|
||||||
|
cb = [self._callback]
|
||||||
|
if callbacks is not None:
|
||||||
|
cb.extend(callbacks)
|
||||||
|
callbacks = cb
|
||||||
|
if class_weight is None:
|
||||||
|
class_weight = self.calculate_class_weights(class_weight)
|
||||||
|
self.class_weight = class_weight
|
||||||
|
out = super(Bolton, self).fit(x=x,
|
||||||
|
y=y,
|
||||||
|
batch_size=batch_size,
|
||||||
|
epochs=epochs,
|
||||||
|
verbose=verbose,
|
||||||
|
callbacks=callbacks,
|
||||||
|
validation_split=validation_split,
|
||||||
|
validation_data=validation_data,
|
||||||
|
shuffle=shuffle,
|
||||||
|
class_weight=class_weight,
|
||||||
|
sample_weight=sample_weight,
|
||||||
|
initial_epoch=initial_epoch,
|
||||||
|
steps_per_epoch=steps_per_epoch,
|
||||||
|
validation_steps=validation_steps,
|
||||||
|
validation_freq=validation_freq,
|
||||||
|
max_queue_size=max_queue_size,
|
||||||
|
workers=workers,
|
||||||
|
use_multiprocessing=use_multiprocessing,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
self._post_fit(x, n_samples)
|
||||||
|
self.__in_fit = False
|
||||||
|
return out
|
||||||
|
|
||||||
|
def fit_generator(self,
|
||||||
|
generator,
|
||||||
|
steps_per_epoch=None,
|
||||||
|
epochs=1,
|
||||||
|
verbose=1,
|
||||||
|
callbacks=None,
|
||||||
|
validation_data=None,
|
||||||
|
validation_steps=None,
|
||||||
|
validation_freq=1,
|
||||||
|
class_weight=None,
|
||||||
|
max_queue_size=10,
|
||||||
|
workers=1,
|
||||||
|
use_multiprocessing=False,
|
||||||
|
shuffle=True,
|
||||||
|
initial_epoch=0,
|
||||||
|
n_samples=None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
This method is the same as fit except for when the passed dataset
|
||||||
|
is a generator. See super method and fit for more details.
|
||||||
|
Args:
|
||||||
|
n_samples: number of individual samples in x
|
||||||
|
|
||||||
|
"""
|
||||||
|
if class_weight is None:
|
||||||
|
class_weight = self.calculate_class_weights(class_weight)
|
||||||
|
self.class_weight = class_weight
|
||||||
|
out = super(Bolton, self).fit_generator(
|
||||||
|
generator,
|
||||||
|
steps_per_epoch=steps_per_epoch,
|
||||||
|
epochs=epochs,
|
||||||
|
verbose=verbose,
|
||||||
|
callbacks=callbacks,
|
||||||
|
validation_data=validation_data,
|
||||||
|
validation_steps=validation_steps,
|
||||||
|
validation_freq=validation_freq,
|
||||||
|
class_weight=class_weight,
|
||||||
|
max_queue_size=max_queue_size,
|
||||||
|
workers=workers,
|
||||||
|
use_multiprocessing=use_multiprocessing,
|
||||||
|
shuffle=shuffle,
|
||||||
|
initial_epoch=initial_epoch
|
||||||
|
)
|
||||||
|
if not self.__in_fit:
|
||||||
|
self._post_fit(generator, n_samples)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def calculate_class_weights(self,
|
||||||
|
class_weights=None,
|
||||||
|
class_counts=None,
|
||||||
|
num_classes=None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Calculates class weighting to be used in training. Can be on
|
||||||
|
Args:
|
||||||
|
class_weights: str specifying type, array giving weights, or None.
|
||||||
|
class_counts: If class_weights is not None, then the number of
|
||||||
|
samples for each class
|
||||||
|
num_classes: If class_weights is not None, then the number of
|
||||||
|
classes.
|
||||||
|
Returns: class_weights as 1D tensor, to be passed to model's fit method.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Value checking
|
||||||
|
class_keys = ['balanced']
|
||||||
|
is_string = False
|
||||||
|
if isinstance(class_weights, str):
|
||||||
|
is_string = True
|
||||||
|
if class_weights not in class_keys:
|
||||||
|
raise ValueError("Detected string class_weights with "
|
||||||
|
"value: {0}, which is not one of {1}."
|
||||||
|
"Please select a valid class_weight type"
|
||||||
|
"or pass an array".format(class_weights,
|
||||||
|
class_keys))
|
||||||
|
if class_counts is None:
|
||||||
|
raise ValueError("Class counts must be provided if using"
|
||||||
|
"class_weights=%s" % class_weights)
|
||||||
|
if num_classes is None:
|
||||||
|
raise ValueError("Class counts must be provided if using"
|
||||||
|
"class_weights=%s" % class_weights)
|
||||||
|
elif class_weights is not None:
|
||||||
|
if num_classes is None:
|
||||||
|
raise ValueError("You must pass a value for num_classes if"
|
||||||
|
"creating an array of class_weights")
|
||||||
|
# performing class weight calculation
|
||||||
|
if class_weights is None:
|
||||||
|
class_weights = 1
|
||||||
|
elif is_string and class_weights == 'balanced':
|
||||||
|
num_samples = sum(class_counts)
|
||||||
|
class_weights = tf.Variable(
|
||||||
|
num_samples / (num_classes * class_counts),
|
||||||
|
dtype=self._dtype
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
class_weights = _ops.convert_to_tensor_v2(class_weights)
|
||||||
|
if len(class_weights.shape) != 1:
|
||||||
|
raise ValueError("Detected class_weights shape: {0} instead of "
|
||||||
|
"1D array".format(class_weights.shape))
|
||||||
|
if class_weights.shape[0] != num_classes:
|
||||||
|
raise ValueError(
|
||||||
|
"Detected array length: {0} instead of: {1}".format(
|
||||||
|
class_weights.shape[0],
|
||||||
|
num_classes
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return class_weights
|
||||||
|
|
||||||
|
def _project_weights_to_r(self, r, force=False):
|
||||||
|
"""helper method to normalize the weights to the R-ball.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
r: radius of "R-Ball". Scalar to normalize to.
|
||||||
|
force: True to normalize regardless of previous weight values.
|
||||||
|
False to check if weights > R-ball and only normalize then.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
for layer in self._layers:
|
||||||
|
weight_norm = tf.norm(layer.kernel, axis=0)
|
||||||
|
if force:
|
||||||
|
layer.kernel = layer.kernel / (weight_norm / r)
|
||||||
|
elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self._dtype)) > 0:
|
||||||
|
layer.kernel = layer.kernel / (weight_norm / r)
|
||||||
|
|
||||||
|
def _get_noise(self, distribution, data_size):
|
||||||
|
"""Sample noise to be added to weights for privacy guarantee
|
||||||
|
|
||||||
|
Args:
|
||||||
|
distribution: the distribution type to pull noise from
|
||||||
|
data_size: the number of samples
|
||||||
|
|
||||||
|
Returns: noise in shape of layer's weights to be added to the weights.
|
||||||
|
|
||||||
|
"""
|
||||||
|
distribution = distribution.lower()
|
||||||
|
input_dim = self._layers[0].kernel.numpy().shape[0]
|
||||||
|
loss = self.loss
|
||||||
|
if distribution == 'laplace':
|
||||||
|
per_class_epsilon = self.epsilon / (self.n_classes)
|
||||||
|
l2_sensitivity = (2 *
|
||||||
|
loss.lipchitz_constant(self.class_weight)) / \
|
||||||
|
(loss.gamma() * data_size)
|
||||||
|
unit_vector = tf.random.normal(shape=(input_dim, self.n_classes),
|
||||||
|
mean=0,
|
||||||
|
seed=1,
|
||||||
|
stddev=1.0,
|
||||||
|
dtype=self._dtype)
|
||||||
|
unit_vector = unit_vector / tf.math.sqrt(
|
||||||
|
tf.reduce_sum(tf.math.square(unit_vector), axis=0)
|
||||||
|
)
|
||||||
|
|
||||||
|
beta = l2_sensitivity / per_class_epsilon
|
||||||
|
alpha = input_dim # input_dim
|
||||||
|
gamma = tf.random.gamma([self.n_classes],
|
||||||
|
alpha,
|
||||||
|
beta=1 / beta,
|
||||||
|
seed=1,
|
||||||
|
dtype=self._dtype)
|
||||||
|
return unit_vector * gamma
|
||||||
|
raise NotImplementedError("distribution: {0} is not "
|
||||||
|
"currently supported".format(distribution))
|
3
privacy/bolton/model_test.py
Normal file
3
privacy/bolton/model_test.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
173
privacy/bolton/optimizer.py
Normal file
173
privacy/bolton/optimizer.py
Normal file
|
@ -0,0 +1,173 @@
|
||||||
|
# Copyright 2018, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Private Optimizer for bolton method"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||||
|
|
||||||
|
_private_attributes = ['_internal_optimizer', 'dtype']
|
||||||
|
|
||||||
|
|
||||||
|
class Private(optimizer_v2.OptimizerV2):
|
||||||
|
"""
|
||||||
|
Private optimizer wraps another tf optimizer to be used
|
||||||
|
as the visible optimizer to the tf model. No matter the optimizer
|
||||||
|
passed, "Private" enables the bolton model to control the learning rate
|
||||||
|
based on the strongly convex loss.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
optimizer: optimizer_v2.OptimizerV2,
|
||||||
|
dtype=tf.float32
|
||||||
|
):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer: Optimizer_v2 or subclass to be used as the optimizer
|
||||||
|
(wrapped).
|
||||||
|
"""
|
||||||
|
self._internal_optimizer = optimizer
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||||
|
"""
|
||||||
|
return self._internal_optimizer.get_config()
|
||||||
|
|
||||||
|
def limit_learning_rate(self, is_eager, beta, gamma):
|
||||||
|
"""Implements learning rate limitation that is required by the bolton
|
||||||
|
method for sensitivity bounding of the strongly convex function.
|
||||||
|
Sets the learning rate to the min(1/beta, 1/(gamma*t))
|
||||||
|
|
||||||
|
Args:
|
||||||
|
is_eager: Whether the model is running in eager mode
|
||||||
|
beta: loss function beta-smoothness
|
||||||
|
gamma: loss function gamma-strongly convex
|
||||||
|
|
||||||
|
Returns: None
|
||||||
|
|
||||||
|
"""
|
||||||
|
numerator = tf.Variable(initial_value=1, dtype=self.dtype)
|
||||||
|
t = tf.cast(self._iterations, self.dtype)
|
||||||
|
# will exist on the internal optimizer
|
||||||
|
pred = numerator / beta < numerator / (gamma * t)
|
||||||
|
if is_eager: # check eagerly
|
||||||
|
if pred:
|
||||||
|
self.learning_rate = numerator / beta
|
||||||
|
else:
|
||||||
|
self.learning_rate = numerator / (gamma * t)
|
||||||
|
else:
|
||||||
|
if pred:
|
||||||
|
self.learning_rate = numerator / beta
|
||||||
|
else:
|
||||||
|
self.learning_rate = numerator / (gamma * t)
|
||||||
|
|
||||||
|
def from_config(self, config, custom_objects=None):
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||||
|
"""
|
||||||
|
return self._internal_optimizer.from_config(
|
||||||
|
config,
|
||||||
|
custom_objects=custom_objects
|
||||||
|
)
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
"""return _internal_optimizer off self instance, and everything else
|
||||||
|
from the _internal_optimizer instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name:
|
||||||
|
|
||||||
|
Returns: attribute from Private if specified to come from self, else
|
||||||
|
from _internal_optimizer.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if name in _private_attributes:
|
||||||
|
return getattr(self, name)
|
||||||
|
optim = object.__getattribute__(self, '_internal_optimizer')
|
||||||
|
return object.__getattribute__(optim, name)
|
||||||
|
|
||||||
|
def __setattr__(self, key, value):
|
||||||
|
""" Set attribute to self instance if its the internal optimizer.
|
||||||
|
Reroute everything else to the _internal_optimizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: attribute name
|
||||||
|
value: attribute value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
if key in _private_attributes:
|
||||||
|
object.__setattr__(self, key, value)
|
||||||
|
else:
|
||||||
|
setattr(self._internal_optimizer, key, value)
|
||||||
|
|
||||||
|
def _resource_apply_dense(self, grad, handle):
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||||
|
"""
|
||||||
|
return self._internal_optimizer._resource_apply_dense(grad, handle)
|
||||||
|
|
||||||
|
def _resource_apply_sparse(self, grad, handle, indices):
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||||
|
"""
|
||||||
|
return self._internal_optimizer._resource_apply_sparse(
|
||||||
|
grad,
|
||||||
|
handle,
|
||||||
|
indices
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_updates(self, loss, params):
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||||
|
"""
|
||||||
|
return self._internal_optimizer.get_updates(loss, params)
|
||||||
|
|
||||||
|
def apply_gradients(self, grads_and_vars, name: str = None):
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||||
|
"""
|
||||||
|
return self._internal_optimizer.apply_gradients(
|
||||||
|
grads_and_vars,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
|
|
||||||
|
def minimize(self,
|
||||||
|
loss,
|
||||||
|
var_list,
|
||||||
|
grad_loss: bool = None,
|
||||||
|
name: str = None
|
||||||
|
):
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||||
|
"""
|
||||||
|
return self._internal_optimizer.minimize(
|
||||||
|
loss,
|
||||||
|
var_list,
|
||||||
|
grad_loss,
|
||||||
|
name
|
||||||
|
)
|
||||||
|
|
||||||
|
def _compute_gradients(self, loss, var_list, grad_loss=None):
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||||
|
"""
|
||||||
|
return self._internal_optimizer._compute_gradients(
|
||||||
|
loss,
|
||||||
|
var_list,
|
||||||
|
grad_loss=grad_loss
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_gradients(self, loss, params):
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||||
|
"""
|
||||||
|
return self._internal_optimizer.get_gradients(loss, params)
|
9
privacy/bolton/optimizer_test.py
Normal file
9
privacy/bolton/optimizer_test.py
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.keras import keras_parameterized
|
||||||
|
from privacy.bolton import model
|
||||||
|
|
Loading…
Reference in a new issue