Working bolton model without unit tests.
-- update to include pull request changes changes include: parameter renaming, changing to mixin, moving model to compile, additional tests, fixing huber loss
This commit is contained in:
parent
5f46927747
commit
751eaead54
7 changed files with 1472 additions and 169 deletions
|
@ -10,5 +10,5 @@ if hasattr(sys, 'skip_tf_privacy_import'): # Useful for standalone scripts.
|
|||
pass
|
||||
else:
|
||||
from privacy.bolton.model import Bolton
|
||||
from privacy.bolton.loss import Huber
|
||||
from privacy.bolton.loss import BinaryCrossentropy
|
||||
from privacy.bolton.loss import StrongConvexHuber
|
||||
from privacy.bolton.loss import StrongConvexBinaryCrossentropy
|
|
@ -20,56 +20,33 @@ import tensorflow as tf
|
|||
from tensorflow.python.keras import losses
|
||||
from tensorflow.python.keras.utils import losses_utils
|
||||
from tensorflow.python.framework import ops as _ops
|
||||
from tensorflow.python.keras.regularizers import L1L2
|
||||
|
||||
|
||||
class StrongConvexLoss(losses.Loss):
|
||||
class StrongConvexMixin:
|
||||
"""
|
||||
Strong Convex Loss base class for any loss function that will be used with
|
||||
Strong Convex Mixin base class for any loss function that will be used with
|
||||
Bolton model. Subclasses must be strongly convex and implement the
|
||||
associated constants. They must also conform to the requirements of tf losses
|
||||
(see super class)
|
||||
(see super class).
|
||||
|
||||
For more details on the strong convexity requirements, see:
|
||||
Bolt-on Differential Privacy for Scalable Stochastic Gradient
|
||||
Descent-based Analytics by Xi Wu et. al.
|
||||
"""
|
||||
def __init__(self,
|
||||
reg_lambda: float,
|
||||
c: float,
|
||||
radius_constant: float = 1,
|
||||
reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
|
||||
name: str = None,
|
||||
dtype=tf.float32,
|
||||
**kwargs):
|
||||
"""
|
||||
Args:
|
||||
reg_lambda: Weight regularization constant
|
||||
c: Additional constant for strongly convex convergence. Acts
|
||||
as a global weight.
|
||||
radius_constant: constant defining the length of the radius
|
||||
reduction: reduction type to use. See super class
|
||||
name: Name of the loss instance
|
||||
dtype: tf datatype to use for tensor conversions.
|
||||
"""
|
||||
super(StrongConvexLoss, self).__init__(reduction=reduction,
|
||||
name=name,
|
||||
**kwargs)
|
||||
self._sample_weight = tf.Variable(initial_value=c,
|
||||
trainable=False,
|
||||
dtype=tf.float32)
|
||||
self._reg_lambda = reg_lambda
|
||||
self.radius_constant = tf.Variable(initial_value=radius_constant,
|
||||
trainable=False,
|
||||
dtype=tf.float32)
|
||||
self.dtype = dtype
|
||||
|
||||
def radius(self):
|
||||
"""Radius of R-Ball (value to normalize weights to after each batch)
|
||||
"""Radius, R, of the hypothesis space W.
|
||||
W is a convex set that forms the hypothesis space.
|
||||
|
||||
Returns: radius
|
||||
Returns: R
|
||||
|
||||
"""
|
||||
raise NotImplementedError("Radius not implemented for StrongConvex Loss"
|
||||
"function: %s" % str(self.__class__.__name__))
|
||||
|
||||
def gamma(self):
|
||||
""" Gamma strongly convex
|
||||
""" Strongly convexity, gamma
|
||||
|
||||
Returns: gamma
|
||||
|
||||
|
@ -78,7 +55,7 @@ class StrongConvexLoss(losses.Loss):
|
|||
"function: %s" % str(self.__class__.__name__))
|
||||
|
||||
def beta(self, class_weight):
|
||||
"""Beta smoothess
|
||||
"""Smoothness, beta
|
||||
|
||||
Args:
|
||||
class_weight: the class weights used.
|
||||
|
@ -90,7 +67,7 @@ class StrongConvexLoss(losses.Loss):
|
|||
"function: %s" % str(self.__class__.__name__))
|
||||
|
||||
def lipchitz_constant(self, class_weight):
|
||||
""" L lipchitz continuous
|
||||
"""Lipchitz constant, L
|
||||
|
||||
Args:
|
||||
class_weight: class weights used
|
||||
|
@ -102,43 +79,46 @@ class StrongConvexLoss(losses.Loss):
|
|||
"StrongConvex Loss"
|
||||
"function: %s" % str(self.__class__.__name__))
|
||||
|
||||
def reg_lambda(self, convert_to_tensor: bool = False):
|
||||
""" returns the lambda weight regularization constant, as a tensor if
|
||||
desired
|
||||
def kernel_regularizer(self):
|
||||
"""returns the kernel_regularizer to be used. Any subclass should override
|
||||
this method if they want a kernel_regularizer (if required for
|
||||
the loss function to be StronglyConvex
|
||||
|
||||
:return: None or kernel_regularizer layer
|
||||
"""
|
||||
return None
|
||||
|
||||
def max_class_weight(self, class_weight, dtype):
|
||||
"""the maximum weighting in class weights (max value) as a scalar tensor
|
||||
|
||||
Args:
|
||||
convert_to_tensor: True to convert to tensor, False to leave as
|
||||
python numeric.
|
||||
class_weight: class weights used
|
||||
dtype: the data type for tensor conversions.
|
||||
|
||||
Returns: reg_lambda
|
||||
Returns: maximum class weighting as tensor scalar
|
||||
|
||||
"""
|
||||
if convert_to_tensor:
|
||||
return _ops.convert_to_tensor_v2(self._reg_lambda, dtype=self.dtype)
|
||||
return self._reg_lambda
|
||||
|
||||
def max_class_weight(self, class_weight):
|
||||
class_weight = _ops.convert_to_tensor_v2(class_weight, dtype=self.dtype)
|
||||
class_weight = _ops.convert_to_tensor_v2(class_weight, dtype)
|
||||
return tf.math.reduce_max(class_weight)
|
||||
|
||||
|
||||
class Huber(StrongConvexLoss, losses.Huber):
|
||||
"""Strong Convex version of huber loss using l2 weight regularization.
|
||||
class StrongConvexHuber(losses.Huber, StrongConvexMixin):
|
||||
"""Strong Convex version of Huber loss using l2 weight regularization.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
reg_lambda: float,
|
||||
c: float,
|
||||
C: float,
|
||||
radius_constant: float,
|
||||
delta: float,
|
||||
reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
|
||||
name: str = 'huber',
|
||||
dtype=tf.float32):
|
||||
"""Constructor. Passes arguments to StrongConvexLoss and Huber Loss.
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
reg_lambda: Weight regularization constant
|
||||
c: Additional constant for strongly convex convergence. Acts
|
||||
as a global weight.
|
||||
C: Penalty parameter C of the loss term
|
||||
radius_constant: constant defining the length of the radius
|
||||
delta: delta value in huber loss. When to switch from quadratic to
|
||||
absolute deviation.
|
||||
|
@ -149,15 +129,22 @@ class Huber(StrongConvexLoss, losses.Huber):
|
|||
Returns:
|
||||
Loss values per sample.
|
||||
"""
|
||||
# self.delta = tf.Variable(initial_value=delta, trainable=False)
|
||||
super(Huber, self).__init__(
|
||||
reg_lambda,
|
||||
c,
|
||||
radius_constant,
|
||||
if C <= 0:
|
||||
raise ValueError('c: {0}, should be >= 0'.format(C))
|
||||
if reg_lambda <= 0:
|
||||
raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
|
||||
if radius_constant <= 0:
|
||||
raise ValueError('radius_constant: {0}, should be >= 0'.format(
|
||||
radius_constant
|
||||
))
|
||||
self.C = C
|
||||
self.radius_constant = radius_constant
|
||||
self.dtype = dtype
|
||||
self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
|
||||
super(StrongConvexHuber, self).__init__(
|
||||
delta=delta,
|
||||
name=name,
|
||||
reduction=reduction,
|
||||
dtype=dtype
|
||||
)
|
||||
|
||||
def call(self, y_true, y_pred):
|
||||
|
@ -170,46 +157,73 @@ class Huber(StrongConvexLoss, losses.Huber):
|
|||
Returns:
|
||||
Loss values per sample.
|
||||
"""
|
||||
return super(Huber, self).call(y_true, y_pred, **self._fn_kwargs) * \
|
||||
self._sample_weight
|
||||
# return super(StrongConvexHuber, self).call(y_true, y_pred) * self._sample_weight
|
||||
h = self._fn_kwargs['delta']
|
||||
z = y_pred * y_true
|
||||
one = tf.constant(1, dtype=self.dtype)
|
||||
four = tf.constant(4, dtype=self.dtype)
|
||||
|
||||
if z > one + h:
|
||||
return z - z
|
||||
elif tf.math.abs(one - z) <= h:
|
||||
return one / (four * h) * tf.math.pow(one + h - z, 2)
|
||||
elif z < one - h:
|
||||
return one - z
|
||||
else:
|
||||
raise ValueError('')
|
||||
|
||||
def radius(self):
|
||||
"""See super class.
|
||||
"""
|
||||
return self.radius_constant / self.reg_lambda(True)
|
||||
return self.radius_constant / self.reg_lambda
|
||||
|
||||
def gamma(self):
|
||||
"""See super class.
|
||||
"""
|
||||
return self.reg_lambda(True)
|
||||
return self.reg_lambda
|
||||
|
||||
def beta(self, class_weight):
|
||||
"""See super class.
|
||||
"""
|
||||
max_class_weight = self.max_class_weight(class_weight)
|
||||
return self._sample_weight * max_class_weight / \
|
||||
(self.delta * tf.Variable(initial_value=2, trainable=False)) + \
|
||||
self.reg_lambda(True)
|
||||
max_class_weight = self.max_class_weight(class_weight, self.dtype)
|
||||
delta = _ops.convert_to_tensor_v2(self._fn_kwargs['delta'],
|
||||
dtype=self.dtype
|
||||
)
|
||||
return self.C * max_class_weight / (delta *
|
||||
tf.constant(2, dtype=self.dtype)) + \
|
||||
self.reg_lambda
|
||||
|
||||
def lipchitz_constant(self, class_weight):
|
||||
"""See super class.
|
||||
"""
|
||||
# if class_weight is provided,
|
||||
# it should be a vector of the same size of number of classes
|
||||
max_class_weight = self.max_class_weight(class_weight)
|
||||
lc = self._sample_weight * max_class_weight + \
|
||||
self.reg_lambda(True) * self.radius()
|
||||
max_class_weight = self.max_class_weight(class_weight, self.dtype)
|
||||
lc = self.C * max_class_weight + \
|
||||
self.reg_lambda * self.radius()
|
||||
return lc
|
||||
|
||||
def kernel_regularizer(self):
|
||||
"""
|
||||
l2 loss using reg_lambda as the l2 term (as desired). Required for
|
||||
this loss function to be strongly convex.
|
||||
:return:
|
||||
"""
|
||||
return L1L2(l2=self.reg_lambda)
|
||||
|
||||
class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy):
|
||||
|
||||
class StrongConvexBinaryCrossentropy(
|
||||
losses.BinaryCrossentropy,
|
||||
StrongConvexMixin
|
||||
):
|
||||
"""
|
||||
Strong Convex version of BinaryCrossentropy loss using l2 weight
|
||||
regularization.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
reg_lambda: float,
|
||||
c: float,
|
||||
C: float,
|
||||
radius_constant: float,
|
||||
from_logits: bool = True,
|
||||
label_smoothing: float = 0,
|
||||
|
@ -219,8 +233,7 @@ class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy):
|
|||
"""
|
||||
Args:
|
||||
reg_lambda: Weight regularization constant
|
||||
c: Additional constant for strongly convex convergence. Acts
|
||||
as a global weight.
|
||||
C: Penalty parameter C of the loss term
|
||||
radius_constant: constant defining the length of the radius
|
||||
reduction: reduction type to use. See super class
|
||||
label_smoothing: amount of smoothing to perform on labels
|
||||
|
@ -228,15 +241,23 @@ class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy):
|
|||
name: Name of the loss instance
|
||||
dtype: tf datatype to use for tensor conversions.
|
||||
"""
|
||||
super(BinaryCrossentropy, self).__init__(reg_lambda,
|
||||
c,
|
||||
radius_constant,
|
||||
reduction=reduction,
|
||||
name=name,
|
||||
from_logits=from_logits,
|
||||
label_smoothing=label_smoothing,
|
||||
dtype=dtype
|
||||
)
|
||||
if reg_lambda <= 0:
|
||||
raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
|
||||
if C <= 0:
|
||||
raise ValueError('c: {0}, should be >= 0'.format(C))
|
||||
if radius_constant <= 0:
|
||||
raise ValueError('radius_constant: {0}, should be >= 0'.format(
|
||||
radius_constant
|
||||
))
|
||||
self.dtype = dtype
|
||||
self.C = C
|
||||
self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
|
||||
super(StrongConvexBinaryCrossentropy, self).__init__(
|
||||
reduction=reduction,
|
||||
name=name,
|
||||
from_logits=from_logits,
|
||||
label_smoothing=label_smoothing,
|
||||
)
|
||||
self.radius_constant = radius_constant
|
||||
|
||||
def call(self, y_true, y_pred):
|
||||
|
@ -249,32 +270,319 @@ class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy):
|
|||
Returns:
|
||||
Loss values per sample.
|
||||
"""
|
||||
loss = tf.nn.sigmoid_cross_entropy_with_logits(
|
||||
labels=y_true,
|
||||
logits=y_pred
|
||||
)
|
||||
loss = loss * self._sample_weight
|
||||
# loss = tf.nn.sigmoid_cross_entropy_with_logits(
|
||||
# labels=y_true,
|
||||
# logits=y_pred
|
||||
# )
|
||||
loss = super(StrongConvexBinaryCrossentropy, self).call(y_true, y_pred)
|
||||
loss = loss * self.C
|
||||
return loss
|
||||
|
||||
def radius(self):
|
||||
"""See super class.
|
||||
"""
|
||||
return self.radius_constant / self.reg_lambda(True)
|
||||
return self.radius_constant / self.reg_lambda
|
||||
|
||||
def gamma(self):
|
||||
"""See super class.
|
||||
"""
|
||||
return self.reg_lambda(True)
|
||||
return self.reg_lambda
|
||||
|
||||
def beta(self, class_weight):
|
||||
"""See super class.
|
||||
"""
|
||||
max_class_weight = self.max_class_weight(class_weight)
|
||||
return self._sample_weight * max_class_weight + self.reg_lambda(True)
|
||||
max_class_weight = self.max_class_weight(class_weight, self.dtype)
|
||||
return self.C * max_class_weight + self.reg_lambda
|
||||
|
||||
def lipchitz_constant(self, class_weight):
|
||||
"""See super class.
|
||||
"""
|
||||
max_class_weight = self.max_class_weight(class_weight)
|
||||
return self._sample_weight * max_class_weight + \
|
||||
self.reg_lambda(True) * self.radius()
|
||||
max_class_weight = self.max_class_weight(class_weight, self.dtype)
|
||||
return self.C * max_class_weight + self.reg_lambda * self.radius()
|
||||
|
||||
def kernel_regularizer(self):
|
||||
"""
|
||||
l2 loss using reg_lambda as the l2 term (as desired). Required for
|
||||
this loss function to be strongly convex.
|
||||
:return:
|
||||
"""
|
||||
return L1L2(l2=self.reg_lambda)
|
||||
|
||||
|
||||
# class StrongConvexSparseCategoricalCrossentropy(
|
||||
# losses.CategoricalCrossentropy,
|
||||
# StrongConvexMixin
|
||||
# ):
|
||||
# """
|
||||
# Strong Convex version of CategoricalCrossentropy loss using l2 weight
|
||||
# regularization.
|
||||
# """
|
||||
#
|
||||
# def __init__(self,
|
||||
# reg_lambda: float,
|
||||
# C: float,
|
||||
# radius_constant: float,
|
||||
# from_logits: bool = True,
|
||||
# label_smoothing: float = 0,
|
||||
# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
|
||||
# name: str = 'binarycrossentropy',
|
||||
# dtype=tf.float32):
|
||||
# """
|
||||
# Args:
|
||||
# reg_lambda: Weight regularization constant
|
||||
# C: Penalty parameter C of the loss term
|
||||
# radius_constant: constant defining the length of the radius
|
||||
# reduction: reduction type to use. See super class
|
||||
# label_smoothing: amount of smoothing to perform on labels
|
||||
# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x)
|
||||
# name: Name of the loss instance
|
||||
# dtype: tf datatype to use for tensor conversions.
|
||||
# """
|
||||
# if reg_lambda <= 0:
|
||||
# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
|
||||
# if C <= 0:
|
||||
# raise ValueError('c: {0}, should be >= 0'.format(C))
|
||||
# if radius_constant <= 0:
|
||||
# raise ValueError('radius_constant: {0}, should be >= 0'.format(
|
||||
# radius_constant
|
||||
# ))
|
||||
#
|
||||
# self.C = C
|
||||
# self.dtype = dtype
|
||||
# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
|
||||
# super(StrongConvexSparseCategoricalCrossentropy, self).__init__(
|
||||
# reduction=reduction,
|
||||
# name=name,
|
||||
# from_logits=from_logits,
|
||||
# label_smoothing=label_smoothing,
|
||||
# )
|
||||
# self.radius_constant = radius_constant
|
||||
#
|
||||
# def call(self, y_true, y_pred):
|
||||
# """Compute loss
|
||||
#
|
||||
# Args:
|
||||
# y_true: Ground truth values.
|
||||
# y_pred: The predicted values.
|
||||
#
|
||||
# Returns:
|
||||
# Loss values per sample.
|
||||
# """
|
||||
# loss = super()
|
||||
# loss = loss * self.C
|
||||
# return loss
|
||||
#
|
||||
# def radius(self):
|
||||
# """See super class.
|
||||
# """
|
||||
# return self.radius_constant / self.reg_lambda
|
||||
#
|
||||
# def gamma(self):
|
||||
# """See super class.
|
||||
# """
|
||||
# return self.reg_lambda
|
||||
#
|
||||
# def beta(self, class_weight):
|
||||
# """See super class.
|
||||
# """
|
||||
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
|
||||
# return self.C * max_class_weight + self.reg_lambda
|
||||
#
|
||||
# def lipchitz_constant(self, class_weight):
|
||||
# """See super class.
|
||||
# """
|
||||
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
|
||||
# return self.C * max_class_weight + self.reg_lambda * self.radius()
|
||||
#
|
||||
# def kernel_regularizer(self):
|
||||
# """
|
||||
# l2 loss using reg_lambda as the l2 term (as desired). Required for
|
||||
# this loss function to be strongly convex.
|
||||
# :return:
|
||||
# """
|
||||
# return L1L2(l2=self.reg_lambda)
|
||||
#
|
||||
# class StrongConvexSparseCategoricalCrossentropy(
|
||||
# losses.SparseCategoricalCrossentropy,
|
||||
# StrongConvexMixin
|
||||
# ):
|
||||
# """
|
||||
# Strong Convex version of SparseCategoricalCrossentropy loss using l2 weight
|
||||
# regularization.
|
||||
# """
|
||||
#
|
||||
# def __init__(self,
|
||||
# reg_lambda: float,
|
||||
# C: float,
|
||||
# radius_constant: float,
|
||||
# from_logits: bool = True,
|
||||
# label_smoothing: float = 0,
|
||||
# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
|
||||
# name: str = 'binarycrossentropy',
|
||||
# dtype=tf.float32):
|
||||
# """
|
||||
# Args:
|
||||
# reg_lambda: Weight regularization constant
|
||||
# C: Penalty parameter C of the loss term
|
||||
# radius_constant: constant defining the length of the radius
|
||||
# reduction: reduction type to use. See super class
|
||||
# label_smoothing: amount of smoothing to perform on labels
|
||||
# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x)
|
||||
# name: Name of the loss instance
|
||||
# dtype: tf datatype to use for tensor conversions.
|
||||
# """
|
||||
# if reg_lambda <= 0:
|
||||
# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
|
||||
# if C <= 0:
|
||||
# raise ValueError('c: {0}, should be >= 0'.format(C))
|
||||
# if radius_constant <= 0:
|
||||
# raise ValueError('radius_constant: {0}, should be >= 0'.format(
|
||||
# radius_constant
|
||||
# ))
|
||||
#
|
||||
# self.C = C
|
||||
# self.dtype = dtype
|
||||
# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
|
||||
# super(StrongConvexHuber, self).__init__(reduction=reduction,
|
||||
# name=name,
|
||||
# from_logits=from_logits,
|
||||
# label_smoothing=label_smoothing,
|
||||
# )
|
||||
# self.radius_constant = radius_constant
|
||||
#
|
||||
# def call(self, y_true, y_pred):
|
||||
# """Compute loss
|
||||
#
|
||||
# Args:
|
||||
# y_true: Ground truth values.
|
||||
# y_pred: The predicted values.
|
||||
#
|
||||
# Returns:
|
||||
# Loss values per sample.
|
||||
# """
|
||||
# loss = super()
|
||||
# loss = loss * self.C
|
||||
# return loss
|
||||
#
|
||||
# def radius(self):
|
||||
# """See super class.
|
||||
# """
|
||||
# return self.radius_constant / self.reg_lambda
|
||||
#
|
||||
# def gamma(self):
|
||||
# """See super class.
|
||||
# """
|
||||
# return self.reg_lambda
|
||||
#
|
||||
# def beta(self, class_weight):
|
||||
# """See super class.
|
||||
# """
|
||||
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
|
||||
# return self.C * max_class_weight + self.reg_lambda
|
||||
#
|
||||
# def lipchitz_constant(self, class_weight):
|
||||
# """See super class.
|
||||
# """
|
||||
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
|
||||
# return self.C * max_class_weight + self.reg_lambda * self.radius()
|
||||
#
|
||||
# def kernel_regularizer(self):
|
||||
# """
|
||||
# l2 loss using reg_lambda as the l2 term (as desired). Required for
|
||||
# this loss function to be strongly convex.
|
||||
# :return:
|
||||
# """
|
||||
# return L1L2(l2=self.reg_lambda)
|
||||
#
|
||||
#
|
||||
# class StrongConvexCategoricalCrossentropy(
|
||||
# losses.CategoricalCrossentropy,
|
||||
# StrongConvexMixin
|
||||
# ):
|
||||
# """
|
||||
# Strong Convex version of CategoricalCrossentropy loss using l2 weight
|
||||
# regularization.
|
||||
# """
|
||||
#
|
||||
# def __init__(self,
|
||||
# reg_lambda: float,
|
||||
# C: float,
|
||||
# radius_constant: float,
|
||||
# from_logits: bool = True,
|
||||
# label_smoothing: float = 0,
|
||||
# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
|
||||
# name: str = 'binarycrossentropy',
|
||||
# dtype=tf.float32):
|
||||
# """
|
||||
# Args:
|
||||
# reg_lambda: Weight regularization constant
|
||||
# C: Penalty parameter C of the loss term
|
||||
# radius_constant: constant defining the length of the radius
|
||||
# reduction: reduction type to use. See super class
|
||||
# label_smoothing: amount of smoothing to perform on labels
|
||||
# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x)
|
||||
# name: Name of the loss instance
|
||||
# dtype: tf datatype to use for tensor conversions.
|
||||
# """
|
||||
# if reg_lambda <= 0:
|
||||
# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
|
||||
# if C <= 0:
|
||||
# raise ValueError('c: {0}, should be >= 0'.format(C))
|
||||
# if radius_constant <= 0:
|
||||
# raise ValueError('radius_constant: {0}, should be >= 0'.format(
|
||||
# radius_constant
|
||||
# ))
|
||||
#
|
||||
# self.C = C
|
||||
# self.dtype = dtype
|
||||
# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
|
||||
# super(StrongConvexHuber, self).__init__(reduction=reduction,
|
||||
# name=name,
|
||||
# from_logits=from_logits,
|
||||
# label_smoothing=label_smoothing,
|
||||
# )
|
||||
# self.radius_constant = radius_constant
|
||||
#
|
||||
# def call(self, y_true, y_pred):
|
||||
# """Compute loss
|
||||
#
|
||||
# Args:
|
||||
# y_true: Ground truth values.
|
||||
# y_pred: The predicted values.
|
||||
#
|
||||
# Returns:
|
||||
# Loss values per sample.
|
||||
# """
|
||||
# loss = super()
|
||||
# loss = loss * self.C
|
||||
# return loss
|
||||
#
|
||||
# def radius(self):
|
||||
# """See super class.
|
||||
# """
|
||||
# return self.radius_constant / self.reg_lambda
|
||||
#
|
||||
# def gamma(self):
|
||||
# """See super class.
|
||||
# """
|
||||
# return self.reg_lambda
|
||||
#
|
||||
# def beta(self, class_weight):
|
||||
# """See super class.
|
||||
# """
|
||||
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
|
||||
# return self.C * max_class_weight + self.reg_lambda
|
||||
#
|
||||
# def lipchitz_constant(self, class_weight):
|
||||
# """See super class.
|
||||
# """
|
||||
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
|
||||
# return self.C * max_class_weight + self.reg_lambda * self.radius()
|
||||
#
|
||||
# def kernel_regularizer(self):
|
||||
# """
|
||||
# l2 loss using reg_lambda as the l2 term (as desired). Required for
|
||||
# this loss function to be strongly convex.
|
||||
# :return:
|
||||
# """
|
||||
# return L1L2(l2=self.reg_lambda)
|
||||
|
|
|
@ -1,3 +1,325 @@
|
|||
# Copyright 2018, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit testing for loss.py"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras.optimizer_v2 import adam
|
||||
from tensorflow.python.keras.optimizer_v2 import adagrad
|
||||
from tensorflow.python.keras.optimizer_v2 import gradient_descent
|
||||
from tensorflow.python.keras import losses
|
||||
from tensorflow.python.framework import test_util
|
||||
from privacy.bolton import model
|
||||
from privacy.bolton.loss import StrongConvexBinaryCrossentropy
|
||||
from privacy.bolton.loss import StrongConvexHuber
|
||||
from privacy.bolton.loss import StrongConvexMixin
|
||||
from absl.testing import parameterized
|
||||
from absl.testing import absltest
|
||||
from tensorflow.python.keras.regularizers import L1L2
|
||||
|
||||
|
||||
class StrongConvexTests(keras_parameterized.TestCase):
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'beta not implemented',
|
||||
'fn': 'beta',
|
||||
'args': [1]},
|
||||
{'testcase_name': 'gamma not implemented',
|
||||
'fn': 'gamma',
|
||||
'args': []},
|
||||
{'testcase_name': 'lipchitz not implemented',
|
||||
'fn': 'lipchitz_constant',
|
||||
'args': [1]},
|
||||
{'testcase_name': 'radius not implemented',
|
||||
'fn': 'radius',
|
||||
'args': []},
|
||||
])
|
||||
def test_not_implemented(self, fn, args):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
loss = StrongConvexMixin()
|
||||
getattr(loss, fn, None)(*args)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'radius not implemented',
|
||||
'fn': 'kernel_regularizer',
|
||||
'args': []},
|
||||
])
|
||||
def test_return_none(self, fn, args):
|
||||
loss = StrongConvexMixin()
|
||||
ret = getattr(loss, fn, None)(*args)
|
||||
self.assertEqual(ret, None)
|
||||
|
||||
|
||||
class BinaryCrossesntropyTests(keras_parameterized.TestCase):
|
||||
"""tests for BinaryCrossesntropy StrongConvex loss"""
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'normal',
|
||||
'reg_lambda': 1,
|
||||
'c': 1,
|
||||
'radius_constant': 1
|
||||
},
|
||||
])
|
||||
def test_init_params(self, reg_lambda, c, radius_constant):
|
||||
# test valid domains for each variable
|
||||
loss = StrongConvexBinaryCrossentropy(reg_lambda, c, radius_constant)
|
||||
self.assertIsInstance(loss, StrongConvexBinaryCrossentropy)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'negative c',
|
||||
'reg_lambda': 1,
|
||||
'c': -1,
|
||||
'radius_constant': 1
|
||||
},
|
||||
{'testcase_name': 'negative radius',
|
||||
'reg_lambda': 1,
|
||||
'c': 1,
|
||||
'radius_constant': -1
|
||||
},
|
||||
{'testcase_name': 'negative lambda',
|
||||
'reg_lambda': -1,
|
||||
'c': 1,
|
||||
'radius_constant': 1
|
||||
},
|
||||
])
|
||||
def test_bad_init_params(self, reg_lambda, c, radius_constant):
|
||||
# test valid domains for each variable
|
||||
with self.assertRaises(ValueError):
|
||||
loss = StrongConvexBinaryCrossentropy(reg_lambda, c, radius_constant)
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
@parameterized.named_parameters([
|
||||
# [] for compatibility with tensorflow loss calculation
|
||||
{'testcase_name': 'both positive',
|
||||
'logits': [10000],
|
||||
'y_true': [1],
|
||||
'result': 0,
|
||||
},
|
||||
{'testcase_name': 'positive gradient negative logits',
|
||||
'logits': [-10000],
|
||||
'y_true': [1],
|
||||
'result': 10000,
|
||||
},
|
||||
{'testcase_name': 'positivee gradient positive logits',
|
||||
'logits': [10000],
|
||||
'y_true': [0],
|
||||
'result': 10000,
|
||||
},
|
||||
{'testcase_name': 'both negative',
|
||||
'logits': [-10000],
|
||||
'y_true': [0],
|
||||
'result': 0
|
||||
},
|
||||
])
|
||||
def test_calculation(self, logits, y_true, result):
|
||||
logits = tf.Variable(logits, False, dtype=tf.float32)
|
||||
y_true = tf.Variable(y_true, False, dtype=tf.float32)
|
||||
loss = StrongConvexBinaryCrossentropy(0.00001, 1, 1)
|
||||
loss = loss(y_true, logits)
|
||||
self.assertEqual(loss.numpy(), result)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'beta',
|
||||
'init_args': [1, 1, 1],
|
||||
'fn': 'beta',
|
||||
'args': [1],
|
||||
'result': tf.constant(2, dtype=tf.float32)
|
||||
},
|
||||
{'testcase_name': 'gamma',
|
||||
'fn': 'gamma',
|
||||
'init_args': [1, 1, 1],
|
||||
'args': [],
|
||||
'result': tf.constant(1, dtype=tf.float32),
|
||||
},
|
||||
{'testcase_name': 'lipchitz constant',
|
||||
'fn': 'lipchitz_constant',
|
||||
'init_args': [1, 1, 1],
|
||||
'args': [1],
|
||||
'result': tf.constant(2, dtype=tf.float32),
|
||||
},
|
||||
{'testcase_name': 'kernel regularizer',
|
||||
'fn': 'kernel_regularizer',
|
||||
'init_args': [1, 1, 1],
|
||||
'args': [],
|
||||
'result': L1L2(l2=1),
|
||||
},
|
||||
])
|
||||
def test_fns(self, init_args, fn, args, result):
|
||||
loss = StrongConvexBinaryCrossentropy(*init_args)
|
||||
expected = getattr(loss, fn, lambda: 'fn not found')(*args)
|
||||
if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor
|
||||
expected = expected.numpy()
|
||||
result = result.numpy()
|
||||
if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer
|
||||
expected = expected.l2
|
||||
result = result.l2
|
||||
self.assertEqual(expected, result)
|
||||
|
||||
|
||||
class HuberTests(keras_parameterized.TestCase):
|
||||
"""tests for BinaryCrossesntropy StrongConvex loss"""
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'normal',
|
||||
'reg_lambda': 1,
|
||||
'c': 1,
|
||||
'radius_constant': 1,
|
||||
'delta': 1,
|
||||
},
|
||||
])
|
||||
def test_init_params(self, reg_lambda, c, radius_constant, delta):
|
||||
# test valid domains for each variable
|
||||
loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta)
|
||||
self.assertIsInstance(loss, StrongConvexHuber)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'negative c',
|
||||
'reg_lambda': 1,
|
||||
'c': -1,
|
||||
'radius_constant': 1,
|
||||
'delta': 1
|
||||
},
|
||||
{'testcase_name': 'negative radius',
|
||||
'reg_lambda': 1,
|
||||
'c': 1,
|
||||
'radius_constant': -1,
|
||||
'delta': 1
|
||||
},
|
||||
{'testcase_name': 'negative lambda',
|
||||
'reg_lambda': -1,
|
||||
'c': 1,
|
||||
'radius_constant': 1,
|
||||
'delta': 1
|
||||
},
|
||||
{'testcase_name': 'negative delta',
|
||||
'reg_lambda': -1,
|
||||
'c': 1,
|
||||
'radius_constant': 1,
|
||||
'delta': -1
|
||||
},
|
||||
])
|
||||
def test_bad_init_params(self, reg_lambda, c, radius_constant, delta):
|
||||
# test valid domains for each variable
|
||||
with self.assertRaises(ValueError):
|
||||
loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta)
|
||||
|
||||
# test the bounds and test varied delta's
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'delta=1,y_true=1 z>1+h decision boundary',
|
||||
'logits': 2.1,
|
||||
'y_true': 1,
|
||||
'delta': 1,
|
||||
'result': 0,
|
||||
},
|
||||
{'testcase_name': 'delta=1,y_true=1 z<1+h decision boundary',
|
||||
'logits': 1.9,
|
||||
'y_true': 1,
|
||||
'delta': 1,
|
||||
'result': 0.01*0.25,
|
||||
},
|
||||
{'testcase_name': 'delta=1,y_true=1 1-z< h decision boundary',
|
||||
'logits': 0.1,
|
||||
'y_true': 1,
|
||||
'delta': 1,
|
||||
'result': 1.9**2 * 0.25,
|
||||
},
|
||||
{'testcase_name': 'delta=1,y_true=1 z < 1-h decision boundary',
|
||||
'logits': -0.1,
|
||||
'y_true': 1,
|
||||
'delta': 1,
|
||||
'result': 1.1,
|
||||
},
|
||||
{'testcase_name': 'delta=2,y_true=1 z>1+h decision boundary',
|
||||
'logits': 3.1,
|
||||
'y_true': 1,
|
||||
'delta': 2,
|
||||
'result': 0,
|
||||
},
|
||||
{'testcase_name': 'delta=2,y_true=1 z<1+h decision boundary',
|
||||
'logits': 2.9,
|
||||
'y_true': 1,
|
||||
'delta': 2,
|
||||
'result': 0.01*0.125,
|
||||
},
|
||||
{'testcase_name': 'delta=2,y_true=1 1-z < h decision boundary',
|
||||
'logits': 1.1,
|
||||
'y_true': 1,
|
||||
'delta': 2,
|
||||
'result': 1.9**2 * 0.125,
|
||||
},
|
||||
{'testcase_name': 'delta=2,y_true=1 z < 1-h decision boundary',
|
||||
'logits': -1.1,
|
||||
'y_true': 1,
|
||||
'delta': 2,
|
||||
'result': 2.1,
|
||||
},
|
||||
{'testcase_name': 'delta=1,y_true=-1 z>1+h decision boundary',
|
||||
'logits': -2.1,
|
||||
'y_true': -1,
|
||||
'delta': 1,
|
||||
'result': 0,
|
||||
},
|
||||
])
|
||||
def test_calculation(self, logits, y_true, delta, result):
|
||||
logits = tf.Variable(logits, False, dtype=tf.float32)
|
||||
y_true = tf.Variable(y_true, False, dtype=tf.float32)
|
||||
loss = StrongConvexHuber(0.00001, 1, 1, delta)
|
||||
loss = loss(y_true, logits)
|
||||
self.assertAllClose(loss.numpy(), result)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'beta',
|
||||
'init_args': [1, 1, 1, 1],
|
||||
'fn': 'beta',
|
||||
'args': [1],
|
||||
'result': tf.Variable(1.5, dtype=tf.float32)
|
||||
},
|
||||
{'testcase_name': 'gamma',
|
||||
'fn': 'gamma',
|
||||
'init_args': [1, 1, 1, 1],
|
||||
'args': [],
|
||||
'result': tf.Variable(1, dtype=tf.float32),
|
||||
},
|
||||
{'testcase_name': 'lipchitz constant',
|
||||
'fn': 'lipchitz_constant',
|
||||
'init_args': [1, 1, 1, 1],
|
||||
'args': [1],
|
||||
'result': tf.Variable(2, dtype=tf.float32),
|
||||
},
|
||||
{'testcase_name': 'kernel regularizer',
|
||||
'fn': 'kernel_regularizer',
|
||||
'init_args': [1, 1, 1, 1],
|
||||
'args': [],
|
||||
'result': L1L2(l2=1),
|
||||
},
|
||||
])
|
||||
def test_fns(self, init_args, fn, args, result):
|
||||
loss = StrongConvexHuber(*init_args)
|
||||
expected = getattr(loss, fn, lambda: 'fn not found')(*args)
|
||||
if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor
|
||||
expected = expected.numpy()
|
||||
result = result.numpy()
|
||||
if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer
|
||||
expected = expected.l2
|
||||
result = result.l2
|
||||
self.assertEqual(expected, result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
|
@ -19,11 +19,12 @@ from __future__ import print_function
|
|||
import tensorflow as tf
|
||||
from tensorflow.python.keras.models import Model
|
||||
from tensorflow.python.keras import optimizers
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.framework import ops as _ops
|
||||
from privacy.bolton.loss import StrongConvexLoss
|
||||
from privacy.bolton.loss import StrongConvexMixin
|
||||
from privacy.bolton.optimizer import Private
|
||||
|
||||
_accepted_distributions = ['laplace']
|
||||
|
||||
|
||||
class Bolton(Model):
|
||||
"""
|
||||
|
@ -33,12 +34,16 @@ class Bolton(Model):
|
|||
2. Projects weights to R after each batch
|
||||
3. Limits learning rate
|
||||
4. Use a strongly convex loss function (see compile)
|
||||
|
||||
For more details on the strong convexity requirements, see:
|
||||
Bolt-on Differential Privacy for Scalable Stochastic Gradient
|
||||
Descent-based Analytics by Xi Wu et. al.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
n_classes,
|
||||
epsilon,
|
||||
noise_distribution='laplace',
|
||||
weights_initializer=tf.initializers.GlorotUniform(),
|
||||
seed=1,
|
||||
dtype=tf.float32
|
||||
):
|
||||
|
@ -59,6 +64,7 @@ class Bolton(Model):
|
|||
2. Projects weights to R after each batch
|
||||
3. Limits learning rate
|
||||
"""
|
||||
|
||||
def on_train_batch_end(self, batch, logs=None):
|
||||
loss = self.model.loss
|
||||
self.model.optimizer.limit_learning_rate(
|
||||
|
@ -72,13 +78,17 @@ class Bolton(Model):
|
|||
loss = self.model.loss
|
||||
self.model._project_weights_to_r(loss.radius(), True)
|
||||
|
||||
if epsilon <= 0:
|
||||
raise ValueError('Detected epsilon: {0}. '
|
||||
'Valid range is 0 < epsilon <inf'.format(epsilon))
|
||||
|
||||
if noise_distribution not in _accepted_distributions:
|
||||
raise ValueError('Detected noise distribution: {0} not one of: {1} valid'
|
||||
'distributions'.format(noise_distribution,
|
||||
_accepted_distributions))
|
||||
|
||||
super(Bolton, self).__init__(name='bolton', dynamic=False)
|
||||
self.n_classes = n_classes
|
||||
self.output_layer = tf.keras.layers.Dense(
|
||||
self.n_classes,
|
||||
kernel_regularizer=tf.keras.regularizers.l2(),
|
||||
kernel_initializer=weights_initializer,
|
||||
)
|
||||
# if we do regularization here, we require the user to re-instantiate
|
||||
# the model each time they want to
|
||||
# change lambda, unless we standardize modifying it later at .compile
|
||||
|
@ -87,6 +97,7 @@ class Bolton(Model):
|
|||
self.epsilon = epsilon
|
||||
self.seed = seed
|
||||
self.__in_fit = False
|
||||
self._layers_instantiated = False
|
||||
self._callback = MyCustomCallback()
|
||||
self._dtype = dtype
|
||||
|
||||
|
@ -114,15 +125,24 @@ class Bolton(Model):
|
|||
"""See super class. Default optimizer used in Bolton method is SGD.
|
||||
|
||||
"""
|
||||
if not isinstance(loss, StrongConvexLoss):
|
||||
raise ValueError("Loss must be subclassed from StrongConvexLoss")
|
||||
self.output_layer.kernel_regularizer.l2 = loss.reg_lambda()
|
||||
for key, val in StrongConvexMixin.__dict__.items():
|
||||
if callable(val) and getattr(loss, key, None) is None:
|
||||
raise ValueError("Please ensure you are passing a valid StrongConvex "
|
||||
"loss that has all the required methods "
|
||||
"implemented. "
|
||||
"Required method: {0} not found".format(key))
|
||||
if not self._layers_instantiated: # compile may be called multiple times
|
||||
kernel_intiializer = kwargs.get('kernel_initializer',
|
||||
tf.initializers.GlorotUniform)
|
||||
self.output_layer = tf.keras.layers.Dense(
|
||||
self.n_classes,
|
||||
kernel_regularizer=loss.kernel_regularizer(),
|
||||
kernel_initializer=kernel_intiializer(),
|
||||
)
|
||||
self._layers_instantiated = True
|
||||
self.output_layer.kernel_regularizer.l2 = loss.reg_lambda
|
||||
if not isinstance(optimizer, Private):
|
||||
optimizer = optimizers.get(optimizer)
|
||||
if isinstance(self.optimizer, trackable.Trackable):
|
||||
self._track_trackable(
|
||||
self.optimizer, name='optimizer', overwrite=True
|
||||
)
|
||||
optimizer = Private(optimizer)
|
||||
|
||||
super(Bolton, self).compile(optimizer,
|
||||
|
@ -149,21 +169,20 @@ class Bolton(Model):
|
|||
Returns:
|
||||
|
||||
"""
|
||||
data_size = None
|
||||
if n_samples is not None:
|
||||
data_size = n_samples
|
||||
elif hasattr(x, 'shape'):
|
||||
data_size = x.shape[0]
|
||||
elif hasattr(x, "__len__"):
|
||||
data_size = len(x)
|
||||
else:
|
||||
elif data_size is None:
|
||||
if n_samples is None:
|
||||
raise ValueError("Unable to detect the number of training "
|
||||
"samples and n_smaples was None. "
|
||||
"either pass a dataset with a .shape or "
|
||||
"__len__ attribute or explicitly pass the "
|
||||
"number of samples as n_smaples.")
|
||||
data_size = n_samples
|
||||
|
||||
for layer in self._layers:
|
||||
layer.kernel = layer.kernel + self._get_noise(
|
||||
self.noise_distribution,
|
||||
|
@ -294,8 +313,8 @@ class Bolton(Model):
|
|||
Calculates class weighting to be used in training. Can be on
|
||||
Args:
|
||||
class_weights: str specifying type, array giving weights, or None.
|
||||
class_counts: If class_weights is not None, then the number of
|
||||
samples for each class
|
||||
class_counts: If class_weights is not None, then an array of
|
||||
the number of samples for each class
|
||||
num_classes: If class_weights is not None, then the number of
|
||||
classes.
|
||||
Returns: class_weights as 1D tensor, to be passed to model's fit method.
|
||||
|
@ -313,10 +332,16 @@ class Bolton(Model):
|
|||
"or pass an array".format(class_weights,
|
||||
class_keys))
|
||||
if class_counts is None:
|
||||
raise ValueError("Class counts must be provided if using"
|
||||
raise ValueError("Class counts must be provided if using "
|
||||
"class_weights=%s" % class_weights)
|
||||
class_counts_shape = tf.Variable(class_counts,
|
||||
trainable=False,
|
||||
dtype=self._dtype).shape
|
||||
if len(class_counts_shape) != 1:
|
||||
raise ValueError('class counts must be a 1D array.'
|
||||
'Detected: {0}'.format(class_counts_shape))
|
||||
if num_classes is None:
|
||||
raise ValueError("Class counts must be provided if using"
|
||||
raise ValueError("num_classes must be provided if using "
|
||||
"class_weights=%s" % class_weights)
|
||||
elif class_weights is not None:
|
||||
if num_classes is None:
|
||||
|
@ -327,10 +352,13 @@ class Bolton(Model):
|
|||
class_weights = 1
|
||||
elif is_string and class_weights == 'balanced':
|
||||
num_samples = sum(class_counts)
|
||||
class_weights = tf.Variable(
|
||||
num_samples / (num_classes * class_counts),
|
||||
dtype=self._dtype
|
||||
)
|
||||
weighted_counts = tf.dtypes.cast(tf.math.multiply(num_classes,
|
||||
class_counts,
|
||||
),
|
||||
self._dtype
|
||||
)
|
||||
class_weights = tf.Variable(num_samples, dtype=self._dtype) / \
|
||||
tf.Variable(weighted_counts, dtype=self._dtype)
|
||||
else:
|
||||
class_weights = _ops.convert_to_tensor_v2(class_weights)
|
||||
if len(class_weights.shape) != 1:
|
||||
|
@ -376,7 +404,7 @@ class Bolton(Model):
|
|||
distribution = distribution.lower()
|
||||
input_dim = self._layers[0].kernel.numpy().shape[0]
|
||||
loss = self.loss
|
||||
if distribution == 'laplace':
|
||||
if distribution == _accepted_distributions[0]: # laplace
|
||||
per_class_epsilon = self.epsilon / (self.n_classes)
|
||||
l2_sensitivity = (2 *
|
||||
loss.lipchitz_constant(self.class_weight)) / \
|
||||
|
@ -396,7 +424,8 @@ class Bolton(Model):
|
|||
alpha,
|
||||
beta=1 / beta,
|
||||
seed=1,
|
||||
dtype=self._dtype)
|
||||
dtype=self._dtype
|
||||
)
|
||||
return unit_vector * gamma
|
||||
raise NotImplementedError("distribution: {0} is not "
|
||||
"currently supported".format(distribution))
|
||||
raise NotImplementedError('Noise distribution: {0} is not '
|
||||
'a valid distribution'.format(distribution))
|
||||
|
|
|
@ -1,3 +1,494 @@
|
|||
# Copyright 2018, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit testing for model.py"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
|
||||
from tensorflow.python.keras import losses
|
||||
from tensorflow.python.framework import ops as _ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from privacy.bolton import model
|
||||
from privacy.bolton.loss import StrongConvexMixin
|
||||
from absl.testing import parameterized
|
||||
from absl.testing import absltest
|
||||
from tensorflow.python.keras.regularizers import L1L2
|
||||
|
||||
|
||||
class TestLoss(losses.Loss):
|
||||
"""Test loss function for testing Bolton model"""
|
||||
def __init__(self, reg_lambda, C, radius_constant, name='test'):
|
||||
super(TestLoss, self).__init__(name=name)
|
||||
self.reg_lambda = reg_lambda
|
||||
self.C = C
|
||||
self.radius_constant = radius_constant
|
||||
|
||||
def radius(self):
|
||||
"""Radius of R-Ball (value to normalize weights to after each batch)
|
||||
|
||||
Returns: radius
|
||||
|
||||
"""
|
||||
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
|
||||
|
||||
def gamma(self):
|
||||
""" Gamma strongly convex
|
||||
|
||||
Returns: gamma
|
||||
|
||||
"""
|
||||
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
|
||||
|
||||
def beta(self, class_weight):
|
||||
"""Beta smoothess
|
||||
|
||||
Args:
|
||||
class_weight: the class weights used.
|
||||
|
||||
Returns: Beta
|
||||
|
||||
"""
|
||||
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
|
||||
|
||||
def lipchitz_constant(self, class_weight):
|
||||
""" L lipchitz continuous
|
||||
|
||||
Args:
|
||||
class_weight: class weights used
|
||||
|
||||
Returns: L
|
||||
|
||||
"""
|
||||
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
|
||||
|
||||
def call(self, val0, val1):
|
||||
"""Loss function that is minimized at the mean of the input points."""
|
||||
return 0.5 * tf.reduce_sum(tf.math.squared_difference(val0, val1), axis=1)
|
||||
|
||||
def max_class_weight(self, class_weight):
|
||||
if class_weight is None:
|
||||
return 1
|
||||
|
||||
def kernel_regularizer(self):
|
||||
return L1L2(l2=self.reg_lambda)
|
||||
|
||||
|
||||
class TestOptimizer(OptimizerV2):
|
||||
"""Test optimizer used for testing Bolton model"""
|
||||
def __init__(self):
|
||||
super(TestOptimizer, self).__init__('test')
|
||||
|
||||
def compute_gradients(self):
|
||||
return 0
|
||||
|
||||
def get_config(self):
|
||||
return {}
|
||||
|
||||
def _create_slots(self, var):
|
||||
pass
|
||||
|
||||
def _resource_apply_dense(self, grad, handle):
|
||||
return grad
|
||||
|
||||
def _resource_apply_sparse(self, grad, handle, indices):
|
||||
return grad
|
||||
|
||||
|
||||
class InitTests(keras_parameterized.TestCase):
|
||||
"""tests for keras model initialization"""
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'normal',
|
||||
'n_classes': 1,
|
||||
'epsilon': 1,
|
||||
'noise_distribution': 'laplace',
|
||||
'seed': 1
|
||||
},
|
||||
{'testcase_name': 'extreme range',
|
||||
'n_classes': 5,
|
||||
'epsilon': 0.1,
|
||||
'noise_distribution': 'laplace',
|
||||
'seed': 10
|
||||
},
|
||||
{'testcase_name': 'extreme range2',
|
||||
'n_classes': 50,
|
||||
'epsilon': 10,
|
||||
'noise_distribution': 'laplace',
|
||||
'seed': 100
|
||||
},
|
||||
])
|
||||
def test_init_params(
|
||||
self, n_classes, epsilon, noise_distribution, seed):
|
||||
# test valid domains for each variable
|
||||
clf = model.Bolton(n_classes,
|
||||
epsilon,
|
||||
noise_distribution,
|
||||
seed
|
||||
)
|
||||
self.assertIsInstance(clf, model.Bolton)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'invalid noise',
|
||||
'n_classes': 1,
|
||||
'epsilon': 1,
|
||||
'noise_distribution': 'not_valid',
|
||||
'weights_initializer': tf.initializers.GlorotUniform(),
|
||||
},
|
||||
{'testcase_name': 'invalid epsilon',
|
||||
'n_classes': 1,
|
||||
'epsilon': -1,
|
||||
'noise_distribution': 'laplace',
|
||||
'weights_initializer': tf.initializers.GlorotUniform(),
|
||||
},
|
||||
])
|
||||
def test_bad_init_params(
|
||||
self, n_classes, epsilon, noise_distribution, weights_initializer):
|
||||
# test invalid domains for each variable, especially noise
|
||||
seed = 1
|
||||
with self.assertRaises(ValueError):
|
||||
clf = model.Bolton(n_classes,
|
||||
epsilon,
|
||||
noise_distribution,
|
||||
weights_initializer,
|
||||
seed
|
||||
)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'string compile',
|
||||
'n_classes': 1,
|
||||
'loss': TestLoss(1, 1, 1),
|
||||
'optimizer': 'adam',
|
||||
'weights_initializer': tf.initializers.GlorotUniform(),
|
||||
},
|
||||
{'testcase_name': 'test compile',
|
||||
'n_classes': 100,
|
||||
'loss': TestLoss(1, 1, 1),
|
||||
'optimizer': TestOptimizer(),
|
||||
'weights_initializer': tf.initializers.GlorotUniform(),
|
||||
},
|
||||
{'testcase_name': 'invalid weights initializer',
|
||||
'n_classes': 1,
|
||||
'loss': TestLoss(1, 1, 1),
|
||||
'optimizer': TestOptimizer(),
|
||||
'weights_initializer': 'not_valid',
|
||||
},
|
||||
])
|
||||
def test_compile(self, n_classes, loss, optimizer, weights_initializer):
|
||||
# test compilation of valid tf.optimizer and tf.loss
|
||||
epsilon = 1
|
||||
noise_distribution = 'laplace'
|
||||
with self.cached_session():
|
||||
clf = model.Bolton(n_classes,
|
||||
epsilon,
|
||||
noise_distribution,
|
||||
weights_initializer
|
||||
)
|
||||
clf.compile(optimizer, loss)
|
||||
self.assertEqual(clf.loss, loss)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'Not strong loss',
|
||||
'n_classes': 1,
|
||||
'loss': losses.BinaryCrossentropy(),
|
||||
'optimizer': 'adam',
|
||||
},
|
||||
{'testcase_name': 'Not valid optimizer',
|
||||
'n_classes': 1,
|
||||
'loss': TestLoss(1, 1, 1),
|
||||
'optimizer': 'ada',
|
||||
}
|
||||
])
|
||||
def test_bad_compile(self, n_classes, loss, optimizer):
|
||||
# test compilaton of invalid tf.optimizer and non instantiated loss.
|
||||
epsilon = 1
|
||||
noise_distribution = 'laplace'
|
||||
weights_initializer = tf.initializers.GlorotUniform()
|
||||
with self.cached_session():
|
||||
with self.assertRaises((ValueError, AttributeError)):
|
||||
clf = model.Bolton(n_classes,
|
||||
epsilon,
|
||||
noise_distribution,
|
||||
weights_initializer
|
||||
)
|
||||
clf.compile(optimizer, loss)
|
||||
|
||||
|
||||
def _cat_dataset(n_samples, input_dim, n_classes, t='train', generator=False):
|
||||
"""
|
||||
Creates a categorically encoded dataset (y is categorical).
|
||||
returns the specified dataset either as a static array or as a generator.
|
||||
Will have evenly split samples across each output class.
|
||||
Each output class will be a different point in the input space.
|
||||
|
||||
Args:
|
||||
n_samples: number of rows
|
||||
input_dim: input dimensionality
|
||||
n_classes: output dimensionality
|
||||
t: one of 'train', 'val', 'test'
|
||||
generator: False for array, True for generator
|
||||
Returns:
|
||||
X as (n_samples, input_dim), Y as (n_samples, n_classes)
|
||||
"""
|
||||
x_stack = []
|
||||
y_stack = []
|
||||
for i_class in range(n_classes):
|
||||
x_stack.append(
|
||||
tf.constant(1*i_class, tf.float32, (n_samples, input_dim))
|
||||
)
|
||||
y_stack.append(
|
||||
tf.constant(i_class, tf.float32, (n_samples, n_classes))
|
||||
)
|
||||
x_set, y_set = tf.stack(x_stack), tf.stack(y_stack)
|
||||
if generator:
|
||||
dataset = tf.data.Dataset.from_tensor_slices(
|
||||
(x_set, y_set)
|
||||
)
|
||||
return dataset
|
||||
return x_set, y_set
|
||||
|
||||
def _do_fit(n_samples,
|
||||
input_dim,
|
||||
n_classes,
|
||||
epsilon,
|
||||
generator,
|
||||
batch_size,
|
||||
reset_n_samples,
|
||||
optimizer,
|
||||
loss,
|
||||
callbacks,
|
||||
distribution='laplace'):
|
||||
clf = model.Bolton(n_classes,
|
||||
epsilon,
|
||||
distribution
|
||||
)
|
||||
clf.compile(optimizer, loss)
|
||||
if generator:
|
||||
x = _cat_dataset(
|
||||
n_samples,
|
||||
input_dim,
|
||||
n_classes,
|
||||
generator=generator
|
||||
)
|
||||
y = None
|
||||
# x = x.batch(batch_size)
|
||||
x = x.shuffle(n_samples//2)
|
||||
batch_size = None
|
||||
else:
|
||||
x, y = _cat_dataset(n_samples, input_dim, n_classes, generator=generator)
|
||||
if reset_n_samples:
|
||||
n_samples = None
|
||||
|
||||
if callbacks is not None:
|
||||
callbacks = [callbacks]
|
||||
clf.fit(x,
|
||||
y,
|
||||
batch_size=batch_size,
|
||||
n_samples=n_samples,
|
||||
callbacks=callbacks
|
||||
)
|
||||
return clf
|
||||
|
||||
|
||||
class TestCallback(tf.keras.callbacks.Callback):
|
||||
pass
|
||||
|
||||
|
||||
class FitTests(keras_parameterized.TestCase):
|
||||
"""Test cases for keras model fitting"""
|
||||
|
||||
# @test_util.run_all_in_graph_and_eager_modes
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'iterator fit',
|
||||
'generator': False,
|
||||
'reset_n_samples': True,
|
||||
'callbacks': None
|
||||
},
|
||||
{'testcase_name': 'iterator fit no samples',
|
||||
'generator': False,
|
||||
'reset_n_samples': True,
|
||||
'callbacks': None
|
||||
},
|
||||
{'testcase_name': 'generator fit',
|
||||
'generator': True,
|
||||
'reset_n_samples': False,
|
||||
'callbacks': None
|
||||
},
|
||||
{'testcase_name': 'with callbacks',
|
||||
'generator': True,
|
||||
'reset_n_samples': False,
|
||||
'callbacks': TestCallback()
|
||||
},
|
||||
])
|
||||
def test_fit(self, generator, reset_n_samples, callbacks):
|
||||
loss = TestLoss(1, 1, 1)
|
||||
optimizer = TestOptimizer()
|
||||
n_classes = 2
|
||||
input_dim = 5
|
||||
epsilon = 1
|
||||
batch_size = 1
|
||||
n_samples = 10
|
||||
clf = _do_fit(n_samples, input_dim, n_classes, epsilon, generator, batch_size,
|
||||
reset_n_samples, optimizer, loss, callbacks)
|
||||
self.assertEqual(hasattr(clf, '_layers'), True)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'generator fit',
|
||||
'generator': True,
|
||||
'reset_n_samples': False,
|
||||
'callbacks': None
|
||||
},
|
||||
])
|
||||
def test_fit_gen(self, generator, reset_n_samples, callbacks):
|
||||
loss = TestLoss(1, 1, 1)
|
||||
optimizer = TestOptimizer()
|
||||
n_classes = 2
|
||||
input_dim = 5
|
||||
epsilon = 1
|
||||
batch_size = 1
|
||||
n_samples = 10
|
||||
clf = model.Bolton(n_classes,
|
||||
epsilon
|
||||
)
|
||||
clf.compile(optimizer, loss)
|
||||
x = _cat_dataset(
|
||||
n_samples,
|
||||
input_dim,
|
||||
n_classes,
|
||||
generator=generator
|
||||
)
|
||||
x = x.batch(batch_size)
|
||||
x = x.shuffle(n_samples // 2)
|
||||
clf.fit_generator(x, n_samples=n_samples)
|
||||
self.assertEqual(hasattr(clf, '_layers'), True)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'iterator no n_samples',
|
||||
'generator': True,
|
||||
'reset_n_samples': True,
|
||||
'distribution': 'laplace'
|
||||
},
|
||||
{'testcase_name': 'invalid distribution',
|
||||
'generator': True,
|
||||
'reset_n_samples': True,
|
||||
'distribution': 'not_valid'
|
||||
},
|
||||
])
|
||||
def test_bad_fit(self, generator, reset_n_samples, distribution):
|
||||
with self.assertRaises(ValueError):
|
||||
loss = TestLoss(1, 1, 1)
|
||||
optimizer = TestOptimizer()
|
||||
n_classes = 2
|
||||
input_dim = 5
|
||||
epsilon = 1
|
||||
batch_size = 1
|
||||
n_samples = 10
|
||||
_do_fit(n_samples, input_dim, n_classes, epsilon, generator, batch_size,
|
||||
reset_n_samples, optimizer, loss, None, distribution)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'None class_weights',
|
||||
'class_weights': None,
|
||||
'class_counts': None,
|
||||
'num_classes': None,
|
||||
'result': 1},
|
||||
{'testcase_name': 'class weights array',
|
||||
'class_weights': [1, 1],
|
||||
'class_counts': [1, 1],
|
||||
'num_classes': 2,
|
||||
'result': [1, 1]},
|
||||
{'testcase_name': 'class weights balanced',
|
||||
'class_weights': 'balanced',
|
||||
'class_counts': [1, 1],
|
||||
'num_classes': 2,
|
||||
'result': [1, 1]},
|
||||
])
|
||||
def test_class_calculate(self,
|
||||
class_weights,
|
||||
class_counts,
|
||||
num_classes,
|
||||
result
|
||||
):
|
||||
clf = model.Bolton(1, 1)
|
||||
expected = clf.calculate_class_weights(class_weights,
|
||||
class_counts,
|
||||
num_classes
|
||||
)
|
||||
|
||||
if hasattr(expected, 'numpy'):
|
||||
expected = expected.numpy()
|
||||
self.assertAllEqual(
|
||||
expected,
|
||||
result
|
||||
)
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'class weight not valid str',
|
||||
'class_weights': 'not_valid',
|
||||
'class_counts': 1,
|
||||
'num_classes': 1,
|
||||
'err_msg': "Detected string class_weights with value: not_valid"},
|
||||
{'testcase_name': 'no class counts',
|
||||
'class_weights': 'balanced',
|
||||
'class_counts': None,
|
||||
'num_classes': 1,
|
||||
'err_msg':
|
||||
"Class counts must be provided if using class_weights=balanced"},
|
||||
{'testcase_name': 'no num classes',
|
||||
'class_weights': 'balanced',
|
||||
'class_counts': [1],
|
||||
'num_classes': None,
|
||||
'err_msg':
|
||||
'num_classes must be provided if using class_weights=balanced'},
|
||||
{'testcase_name': 'class counts not array',
|
||||
'class_weights': 'balanced',
|
||||
'class_counts': 1,
|
||||
'num_classes': None,
|
||||
'err_msg': 'class counts must be a 1D array.'},
|
||||
{'testcase_name': 'class counts array, no num classes',
|
||||
'class_weights': [1],
|
||||
'class_counts': None,
|
||||
'num_classes': None,
|
||||
'err_msg': "You must pass a value for num_classes if"
|
||||
"creating an array of class_weights"},
|
||||
{'testcase_name': 'class counts array, improper shape',
|
||||
'class_weights': [[1], [1]],
|
||||
'class_counts': None,
|
||||
'num_classes': 2,
|
||||
'err_msg': "Detected class_weights shape"},
|
||||
{'testcase_name': 'class counts array, wrong number classes',
|
||||
'class_weights': [1, 1, 1],
|
||||
'class_counts': None,
|
||||
'num_classes': 2,
|
||||
'err_msg': "Detected array length:"},
|
||||
])
|
||||
def test_class_errors(self,
|
||||
class_weights,
|
||||
class_counts,
|
||||
num_classes,
|
||||
err_msg):
|
||||
clf = model.Bolton(1, 1)
|
||||
with self.assertRaisesRegexp(ValueError, err_msg):
|
||||
expected = clf.calculate_class_weights(class_weights,
|
||||
class_counts,
|
||||
num_classes
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
|
|
|
@ -29,6 +29,10 @@ class Private(optimizer_v2.OptimizerV2):
|
|||
as the visible optimizer to the tf model. No matter the optimizer
|
||||
passed, "Private" enables the bolton model to control the learning rate
|
||||
based on the strongly convex loss.
|
||||
|
||||
For more details on the strong convexity requirements, see:
|
||||
Bolt-on Differential Privacy for Scalable Stochastic Gradient
|
||||
Descent-based Analytics by Xi Wu et. al.
|
||||
"""
|
||||
def __init__(self,
|
||||
optimizer: optimizer_v2.OptimizerV2,
|
||||
|
@ -76,13 +80,10 @@ class Private(optimizer_v2.OptimizerV2):
|
|||
else:
|
||||
self.learning_rate = numerator / (gamma * t)
|
||||
|
||||
def from_config(self, config, custom_objects=None):
|
||||
def from_config(self, *args, **kwargs):
|
||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||
"""
|
||||
return self._internal_optimizer.from_config(
|
||||
config,
|
||||
custom_objects=custom_objects
|
||||
)
|
||||
return self._internal_optimizer.from_config(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""return _internal_optimizer off self instance, and everything else
|
||||
|
@ -116,58 +117,37 @@ class Private(optimizer_v2.OptimizerV2):
|
|||
else:
|
||||
setattr(self._internal_optimizer, key, value)
|
||||
|
||||
def _resource_apply_dense(self, grad, handle):
|
||||
def _resource_apply_dense(self, *args, **kwargs):
|
||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||
"""
|
||||
return self._internal_optimizer._resource_apply_dense(grad, handle)
|
||||
return self._internal_optimizer._resource_apply_dense(*args, **kwargs)
|
||||
|
||||
def _resource_apply_sparse(self, grad, handle, indices):
|
||||
def _resource_apply_sparse(self, *args, **kwargs):
|
||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||
"""
|
||||
return self._internal_optimizer._resource_apply_sparse(
|
||||
grad,
|
||||
handle,
|
||||
indices
|
||||
)
|
||||
return self._internal_optimizer._resource_apply_sparse(*args, **kwargs)
|
||||
|
||||
def get_updates(self, loss, params):
|
||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||
"""
|
||||
return self._internal_optimizer.get_updates(loss, params)
|
||||
|
||||
def apply_gradients(self, grads_and_vars, name: str = None):
|
||||
def apply_gradients(self, *args, **kwargs):
|
||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||
"""
|
||||
return self._internal_optimizer.apply_gradients(
|
||||
grads_and_vars,
|
||||
name=name
|
||||
)
|
||||
return self._internal_optimizer.apply_gradients(*args, **kwargs)
|
||||
|
||||
def minimize(self,
|
||||
loss,
|
||||
var_list,
|
||||
grad_loss: bool = None,
|
||||
name: str = None
|
||||
):
|
||||
def minimize(self, *args, **kwargs):
|
||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||
"""
|
||||
return self._internal_optimizer.minimize(
|
||||
loss,
|
||||
var_list,
|
||||
grad_loss,
|
||||
name
|
||||
)
|
||||
return self._internal_optimizer.minimize(*args, **kwargs)
|
||||
|
||||
def _compute_gradients(self, loss, var_list, grad_loss=None):
|
||||
def _compute_gradients(self, *args, **kwargs):
|
||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||
"""
|
||||
return self._internal_optimizer._compute_gradients(
|
||||
loss,
|
||||
var_list,
|
||||
grad_loss=grad_loss
|
||||
)
|
||||
return self._internal_optimizer._compute_gradients(*args, **kwargs)
|
||||
|
||||
def get_gradients(self, loss, params):
|
||||
def get_gradients(self, *args, **kwargs):
|
||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||
"""
|
||||
return self._internal_optimizer.get_gradients(loss, params)
|
||||
return self._internal_optimizer.get_gradients(*args, **kwargs)
|
||||
|
|
|
@ -1,9 +1,182 @@
|
|||
# Copyright 2018, The TensorFlow Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit testing for optimizer.py"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from privacy.bolton import model
|
||||
from privacy.bolton import optimizer as opt
|
||||
from absl.testing import parameterized
|
||||
from absl.testing import absltest
|
||||
|
||||
|
||||
class TestOptimizer(OptimizerV2):
|
||||
"""Optimizer used for testing the Private optimizer"""
|
||||
def __init__(self):
|
||||
super(TestOptimizer, self).__init__('test')
|
||||
self.not_private = 'test'
|
||||
self.iterations = tf.Variable(1, dtype=tf.float32)
|
||||
self._iterations = tf.Variable(1, dtype=tf.float32)
|
||||
|
||||
def _compute_gradients(self, loss, var_list, grad_loss=None):
|
||||
return 'test'
|
||||
|
||||
def get_config(self):
|
||||
return 'test'
|
||||
|
||||
def from_config(cls, config, custom_objects=None):
|
||||
return 'test'
|
||||
|
||||
def _create_slots(self):
|
||||
return 'test'
|
||||
|
||||
def _resource_apply_dense(self, grad, handle):
|
||||
return 'test'
|
||||
|
||||
def _resource_apply_sparse(self, grad, handle, indices):
|
||||
return 'test'
|
||||
|
||||
def get_updates(self, loss, params):
|
||||
return 'test'
|
||||
|
||||
def apply_gradients(self, grads_and_vars, name=None):
|
||||
return 'test'
|
||||
|
||||
def minimize(self, loss, var_list, grad_loss=None, name=None):
|
||||
return 'test'
|
||||
|
||||
def get_gradients(self, loss, params):
|
||||
return 'test'
|
||||
|
||||
class PrivateTest(keras_parameterized.TestCase):
|
||||
"""Private Optimizer tests"""
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'branch True, beta',
|
||||
'fn': 'limit_learning_rate',
|
||||
'args': [True,
|
||||
tf.Variable(2, dtype=tf.float32),
|
||||
tf.Variable(1, dtype=tf.float32)],
|
||||
'result': tf.Variable(0.5, dtype=tf.float32),
|
||||
'test_attr': 'learning_rate'},
|
||||
{'testcase_name': 'branch True, gamma',
|
||||
'fn': 'limit_learning_rate',
|
||||
'args': [True,
|
||||
tf.Variable(1, dtype=tf.float32),
|
||||
tf.Variable(1, dtype=tf.float32)],
|
||||
'result': tf.Variable(1, dtype=tf.float32),
|
||||
'test_attr': 'learning_rate'},
|
||||
{'testcase_name': 'branch False, beta',
|
||||
'fn': 'limit_learning_rate',
|
||||
'args': [False,
|
||||
tf.Variable(2, dtype=tf.float32),
|
||||
tf.Variable(1, dtype=tf.float32)],
|
||||
'result': tf.Variable(0.5, dtype=tf.float32),
|
||||
'test_attr': 'learning_rate'},
|
||||
{'testcase_name': 'branch False, gamma',
|
||||
'fn': 'limit_learning_rate',
|
||||
'args': [False,
|
||||
tf.Variable(1, dtype=tf.float32),
|
||||
tf.Variable(1, dtype=tf.float32)],
|
||||
'result': tf.Variable(1, dtype=tf.float32),
|
||||
'test_attr': 'learning_rate'},
|
||||
{'testcase_name': 'getattr',
|
||||
'fn': '__getattr__',
|
||||
'args': ['dtype'],
|
||||
'result': tf.float32,
|
||||
'test_attr': None},
|
||||
])
|
||||
def test_fn(self, fn, args, result, test_attr):
|
||||
private = opt.Private(TestOptimizer())
|
||||
res = getattr(private, fn, None)(*args)
|
||||
if test_attr is not None:
|
||||
res = getattr(private, test_attr, None)
|
||||
if hasattr(res, 'numpy') and hasattr(result, 'numpy'): # both tensors/not
|
||||
res = res.numpy()
|
||||
result = result.numpy()
|
||||
self.assertEqual(res, result)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'fn: get_updates',
|
||||
'fn': 'get_updates',
|
||||
'args': [0, 0]},
|
||||
{'testcase_name': 'fn: get_config',
|
||||
'fn': 'get_config',
|
||||
'args': []},
|
||||
{'testcase_name': 'fn: from_config',
|
||||
'fn': 'from_config',
|
||||
'args': [0]},
|
||||
{'testcase_name': 'fn: _resource_apply_dense',
|
||||
'fn': '_resource_apply_dense',
|
||||
'args': [1, 1]},
|
||||
{'testcase_name': 'fn: _resource_apply_sparse',
|
||||
'fn': '_resource_apply_sparse',
|
||||
'args': [1, 1, 1]},
|
||||
{'testcase_name': 'fn: apply_gradients',
|
||||
'fn': 'apply_gradients',
|
||||
'args': [1]},
|
||||
{'testcase_name': 'fn: minimize',
|
||||
'fn': 'minimize',
|
||||
'args': [1, 1]},
|
||||
{'testcase_name': 'fn: _compute_gradients',
|
||||
'fn': '_compute_gradients',
|
||||
'args': [1, 1]},
|
||||
{'testcase_name': 'fn: get_gradients',
|
||||
'fn': 'get_gradients',
|
||||
'args': [1, 1]},
|
||||
])
|
||||
def test_rerouted_function(self, fn, args):
|
||||
optimizer = TestOptimizer()
|
||||
optimizer = opt.Private(optimizer)
|
||||
self.assertEqual(
|
||||
getattr(optimizer, fn, lambda: 'fn not found')(*args),
|
||||
'test'
|
||||
)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'fn: limit_learning_rate',
|
||||
'fn': 'limit_learning_rate',
|
||||
'args': [1, 1, 1]}
|
||||
])
|
||||
def test_not_reroute_fn(self, fn, args):
|
||||
optimizer = TestOptimizer()
|
||||
optimizer = opt.Private(optimizer)
|
||||
self.assertNotEqual(getattr(optimizer, fn, lambda: 'test')(*args),
|
||||
'test')
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'attr: not_private',
|
||||
'attr': 'not_private'}
|
||||
])
|
||||
def test_reroute_attr(self, attr):
|
||||
internal_optimizer = TestOptimizer()
|
||||
optimizer = opt.Private(internal_optimizer)
|
||||
self.assertEqual(optimizer._internal_optimizer, internal_optimizer)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'attr: _internal_optimizer',
|
||||
'attr': '_internal_optimizer'}
|
||||
])
|
||||
def test_not_reroute_attr(self, attr):
|
||||
internal_optimizer = TestOptimizer()
|
||||
optimizer = opt.Private(internal_optimizer)
|
||||
self.assertEqual(optimizer._internal_optimizer, internal_optimizer)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
Loading…
Reference in a new issue