Working bolton model without unit tests.

-- update to include pull request changes
changes include:
parameter renaming,
changing to mixin,
moving model to compile,
additional tests,
fixing huber loss
This commit is contained in:
Christopher Choquette Choo 2019-06-10 16:11:47 -04:00
parent 5f46927747
commit 751eaead54
7 changed files with 1472 additions and 169 deletions

View file

@ -10,5 +10,5 @@ if hasattr(sys, 'skip_tf_privacy_import'): # Useful for standalone scripts.
pass pass
else: else:
from privacy.bolton.model import Bolton from privacy.bolton.model import Bolton
from privacy.bolton.loss import Huber from privacy.bolton.loss import StrongConvexHuber
from privacy.bolton.loss import BinaryCrossentropy from privacy.bolton.loss import StrongConvexBinaryCrossentropy

View file

@ -20,56 +20,33 @@ import tensorflow as tf
from tensorflow.python.keras import losses from tensorflow.python.keras import losses
from tensorflow.python.keras.utils import losses_utils from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.framework import ops as _ops from tensorflow.python.framework import ops as _ops
from tensorflow.python.keras.regularizers import L1L2
class StrongConvexLoss(losses.Loss): class StrongConvexMixin:
""" """
Strong Convex Loss base class for any loss function that will be used with Strong Convex Mixin base class for any loss function that will be used with
Bolton model. Subclasses must be strongly convex and implement the Bolton model. Subclasses must be strongly convex and implement the
associated constants. They must also conform to the requirements of tf losses associated constants. They must also conform to the requirements of tf losses
(see super class) (see super class).
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et. al.
""" """
def __init__(self,
reg_lambda: float,
c: float,
radius_constant: float = 1,
reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
name: str = None,
dtype=tf.float32,
**kwargs):
"""
Args:
reg_lambda: Weight regularization constant
c: Additional constant for strongly convex convergence. Acts
as a global weight.
radius_constant: constant defining the length of the radius
reduction: reduction type to use. See super class
name: Name of the loss instance
dtype: tf datatype to use for tensor conversions.
"""
super(StrongConvexLoss, self).__init__(reduction=reduction,
name=name,
**kwargs)
self._sample_weight = tf.Variable(initial_value=c,
trainable=False,
dtype=tf.float32)
self._reg_lambda = reg_lambda
self.radius_constant = tf.Variable(initial_value=radius_constant,
trainable=False,
dtype=tf.float32)
self.dtype = dtype
def radius(self): def radius(self):
"""Radius of R-Ball (value to normalize weights to after each batch) """Radius, R, of the hypothesis space W.
W is a convex set that forms the hypothesis space.
Returns: radius Returns: R
""" """
raise NotImplementedError("Radius not implemented for StrongConvex Loss" raise NotImplementedError("Radius not implemented for StrongConvex Loss"
"function: %s" % str(self.__class__.__name__)) "function: %s" % str(self.__class__.__name__))
def gamma(self): def gamma(self):
""" Gamma strongly convex """ Strongly convexity, gamma
Returns: gamma Returns: gamma
@ -78,7 +55,7 @@ class StrongConvexLoss(losses.Loss):
"function: %s" % str(self.__class__.__name__)) "function: %s" % str(self.__class__.__name__))
def beta(self, class_weight): def beta(self, class_weight):
"""Beta smoothess """Smoothness, beta
Args: Args:
class_weight: the class weights used. class_weight: the class weights used.
@ -90,7 +67,7 @@ class StrongConvexLoss(losses.Loss):
"function: %s" % str(self.__class__.__name__)) "function: %s" % str(self.__class__.__name__))
def lipchitz_constant(self, class_weight): def lipchitz_constant(self, class_weight):
""" L lipchitz continuous """Lipchitz constant, L
Args: Args:
class_weight: class weights used class_weight: class weights used
@ -102,43 +79,46 @@ class StrongConvexLoss(losses.Loss):
"StrongConvex Loss" "StrongConvex Loss"
"function: %s" % str(self.__class__.__name__)) "function: %s" % str(self.__class__.__name__))
def reg_lambda(self, convert_to_tensor: bool = False): def kernel_regularizer(self):
""" returns the lambda weight regularization constant, as a tensor if """returns the kernel_regularizer to be used. Any subclass should override
desired this method if they want a kernel_regularizer (if required for
the loss function to be StronglyConvex
:return: None or kernel_regularizer layer
"""
return None
def max_class_weight(self, class_weight, dtype):
"""the maximum weighting in class weights (max value) as a scalar tensor
Args: Args:
convert_to_tensor: True to convert to tensor, False to leave as class_weight: class weights used
python numeric. dtype: the data type for tensor conversions.
Returns: reg_lambda Returns: maximum class weighting as tensor scalar
""" """
if convert_to_tensor: class_weight = _ops.convert_to_tensor_v2(class_weight, dtype)
return _ops.convert_to_tensor_v2(self._reg_lambda, dtype=self.dtype)
return self._reg_lambda
def max_class_weight(self, class_weight):
class_weight = _ops.convert_to_tensor_v2(class_weight, dtype=self.dtype)
return tf.math.reduce_max(class_weight) return tf.math.reduce_max(class_weight)
class Huber(StrongConvexLoss, losses.Huber): class StrongConvexHuber(losses.Huber, StrongConvexMixin):
"""Strong Convex version of huber loss using l2 weight regularization. """Strong Convex version of Huber loss using l2 weight regularization.
""" """
def __init__(self, def __init__(self,
reg_lambda: float, reg_lambda: float,
c: float, C: float,
radius_constant: float, radius_constant: float,
delta: float, delta: float,
reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
name: str = 'huber', name: str = 'huber',
dtype=tf.float32): dtype=tf.float32):
"""Constructor. Passes arguments to StrongConvexLoss and Huber Loss. """Constructor.
Args: Args:
reg_lambda: Weight regularization constant reg_lambda: Weight regularization constant
c: Additional constant for strongly convex convergence. Acts C: Penalty parameter C of the loss term
as a global weight.
radius_constant: constant defining the length of the radius radius_constant: constant defining the length of the radius
delta: delta value in huber loss. When to switch from quadratic to delta: delta value in huber loss. When to switch from quadratic to
absolute deviation. absolute deviation.
@ -149,15 +129,22 @@ class Huber(StrongConvexLoss, losses.Huber):
Returns: Returns:
Loss values per sample. Loss values per sample.
""" """
# self.delta = tf.Variable(initial_value=delta, trainable=False) if C <= 0:
super(Huber, self).__init__( raise ValueError('c: {0}, should be >= 0'.format(C))
reg_lambda, if reg_lambda <= 0:
c, raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
radius_constant, if radius_constant <= 0:
raise ValueError('radius_constant: {0}, should be >= 0'.format(
radius_constant
))
self.C = C
self.radius_constant = radius_constant
self.dtype = dtype
self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
super(StrongConvexHuber, self).__init__(
delta=delta, delta=delta,
name=name, name=name,
reduction=reduction, reduction=reduction,
dtype=dtype
) )
def call(self, y_true, y_pred): def call(self, y_true, y_pred):
@ -170,46 +157,73 @@ class Huber(StrongConvexLoss, losses.Huber):
Returns: Returns:
Loss values per sample. Loss values per sample.
""" """
return super(Huber, self).call(y_true, y_pred, **self._fn_kwargs) * \ # return super(StrongConvexHuber, self).call(y_true, y_pred) * self._sample_weight
self._sample_weight h = self._fn_kwargs['delta']
z = y_pred * y_true
one = tf.constant(1, dtype=self.dtype)
four = tf.constant(4, dtype=self.dtype)
if z > one + h:
return z - z
elif tf.math.abs(one - z) <= h:
return one / (four * h) * tf.math.pow(one + h - z, 2)
elif z < one - h:
return one - z
else:
raise ValueError('')
def radius(self): def radius(self):
"""See super class. """See super class.
""" """
return self.radius_constant / self.reg_lambda(True) return self.radius_constant / self.reg_lambda
def gamma(self): def gamma(self):
"""See super class. """See super class.
""" """
return self.reg_lambda(True) return self.reg_lambda
def beta(self, class_weight): def beta(self, class_weight):
"""See super class. """See super class.
""" """
max_class_weight = self.max_class_weight(class_weight) max_class_weight = self.max_class_weight(class_weight, self.dtype)
return self._sample_weight * max_class_weight / \ delta = _ops.convert_to_tensor_v2(self._fn_kwargs['delta'],
(self.delta * tf.Variable(initial_value=2, trainable=False)) + \ dtype=self.dtype
self.reg_lambda(True) )
return self.C * max_class_weight / (delta *
tf.constant(2, dtype=self.dtype)) + \
self.reg_lambda
def lipchitz_constant(self, class_weight): def lipchitz_constant(self, class_weight):
"""See super class. """See super class.
""" """
# if class_weight is provided, # if class_weight is provided,
# it should be a vector of the same size of number of classes # it should be a vector of the same size of number of classes
max_class_weight = self.max_class_weight(class_weight) max_class_weight = self.max_class_weight(class_weight, self.dtype)
lc = self._sample_weight * max_class_weight + \ lc = self.C * max_class_weight + \
self.reg_lambda(True) * self.radius() self.reg_lambda * self.radius()
return lc return lc
def kernel_regularizer(self):
"""
l2 loss using reg_lambda as the l2 term (as desired). Required for
this loss function to be strongly convex.
:return:
"""
return L1L2(l2=self.reg_lambda)
class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy):
class StrongConvexBinaryCrossentropy(
losses.BinaryCrossentropy,
StrongConvexMixin
):
""" """
Strong Convex version of BinaryCrossentropy loss using l2 weight Strong Convex version of BinaryCrossentropy loss using l2 weight
regularization. regularization.
""" """
def __init__(self, def __init__(self,
reg_lambda: float, reg_lambda: float,
c: float, C: float,
radius_constant: float, radius_constant: float,
from_logits: bool = True, from_logits: bool = True,
label_smoothing: float = 0, label_smoothing: float = 0,
@ -219,8 +233,7 @@ class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy):
""" """
Args: Args:
reg_lambda: Weight regularization constant reg_lambda: Weight regularization constant
c: Additional constant for strongly convex convergence. Acts C: Penalty parameter C of the loss term
as a global weight.
radius_constant: constant defining the length of the radius radius_constant: constant defining the length of the radius
reduction: reduction type to use. See super class reduction: reduction type to use. See super class
label_smoothing: amount of smoothing to perform on labels label_smoothing: amount of smoothing to perform on labels
@ -228,15 +241,23 @@ class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy):
name: Name of the loss instance name: Name of the loss instance
dtype: tf datatype to use for tensor conversions. dtype: tf datatype to use for tensor conversions.
""" """
super(BinaryCrossentropy, self).__init__(reg_lambda, if reg_lambda <= 0:
c, raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
radius_constant, if C <= 0:
reduction=reduction, raise ValueError('c: {0}, should be >= 0'.format(C))
name=name, if radius_constant <= 0:
from_logits=from_logits, raise ValueError('radius_constant: {0}, should be >= 0'.format(
label_smoothing=label_smoothing, radius_constant
dtype=dtype ))
) self.dtype = dtype
self.C = C
self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
super(StrongConvexBinaryCrossentropy, self).__init__(
reduction=reduction,
name=name,
from_logits=from_logits,
label_smoothing=label_smoothing,
)
self.radius_constant = radius_constant self.radius_constant = radius_constant
def call(self, y_true, y_pred): def call(self, y_true, y_pred):
@ -249,32 +270,319 @@ class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy):
Returns: Returns:
Loss values per sample. Loss values per sample.
""" """
loss = tf.nn.sigmoid_cross_entropy_with_logits( # loss = tf.nn.sigmoid_cross_entropy_with_logits(
labels=y_true, # labels=y_true,
logits=y_pred # logits=y_pred
) # )
loss = loss * self._sample_weight loss = super(StrongConvexBinaryCrossentropy, self).call(y_true, y_pred)
loss = loss * self.C
return loss return loss
def radius(self): def radius(self):
"""See super class. """See super class.
""" """
return self.radius_constant / self.reg_lambda(True) return self.radius_constant / self.reg_lambda
def gamma(self): def gamma(self):
"""See super class. """See super class.
""" """
return self.reg_lambda(True) return self.reg_lambda
def beta(self, class_weight): def beta(self, class_weight):
"""See super class. """See super class.
""" """
max_class_weight = self.max_class_weight(class_weight) max_class_weight = self.max_class_weight(class_weight, self.dtype)
return self._sample_weight * max_class_weight + self.reg_lambda(True) return self.C * max_class_weight + self.reg_lambda
def lipchitz_constant(self, class_weight): def lipchitz_constant(self, class_weight):
"""See super class. """See super class.
""" """
max_class_weight = self.max_class_weight(class_weight) max_class_weight = self.max_class_weight(class_weight, self.dtype)
return self._sample_weight * max_class_weight + \ return self.C * max_class_weight + self.reg_lambda * self.radius()
self.reg_lambda(True) * self.radius()
def kernel_regularizer(self):
"""
l2 loss using reg_lambda as the l2 term (as desired). Required for
this loss function to be strongly convex.
:return:
"""
return L1L2(l2=self.reg_lambda)
# class StrongConvexSparseCategoricalCrossentropy(
# losses.CategoricalCrossentropy,
# StrongConvexMixin
# ):
# """
# Strong Convex version of CategoricalCrossentropy loss using l2 weight
# regularization.
# """
#
# def __init__(self,
# reg_lambda: float,
# C: float,
# radius_constant: float,
# from_logits: bool = True,
# label_smoothing: float = 0,
# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
# name: str = 'binarycrossentropy',
# dtype=tf.float32):
# """
# Args:
# reg_lambda: Weight regularization constant
# C: Penalty parameter C of the loss term
# radius_constant: constant defining the length of the radius
# reduction: reduction type to use. See super class
# label_smoothing: amount of smoothing to perform on labels
# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x)
# name: Name of the loss instance
# dtype: tf datatype to use for tensor conversions.
# """
# if reg_lambda <= 0:
# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
# if C <= 0:
# raise ValueError('c: {0}, should be >= 0'.format(C))
# if radius_constant <= 0:
# raise ValueError('radius_constant: {0}, should be >= 0'.format(
# radius_constant
# ))
#
# self.C = C
# self.dtype = dtype
# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
# super(StrongConvexSparseCategoricalCrossentropy, self).__init__(
# reduction=reduction,
# name=name,
# from_logits=from_logits,
# label_smoothing=label_smoothing,
# )
# self.radius_constant = radius_constant
#
# def call(self, y_true, y_pred):
# """Compute loss
#
# Args:
# y_true: Ground truth values.
# y_pred: The predicted values.
#
# Returns:
# Loss values per sample.
# """
# loss = super()
# loss = loss * self.C
# return loss
#
# def radius(self):
# """See super class.
# """
# return self.radius_constant / self.reg_lambda
#
# def gamma(self):
# """See super class.
# """
# return self.reg_lambda
#
# def beta(self, class_weight):
# """See super class.
# """
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
# return self.C * max_class_weight + self.reg_lambda
#
# def lipchitz_constant(self, class_weight):
# """See super class.
# """
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
# return self.C * max_class_weight + self.reg_lambda * self.radius()
#
# def kernel_regularizer(self):
# """
# l2 loss using reg_lambda as the l2 term (as desired). Required for
# this loss function to be strongly convex.
# :return:
# """
# return L1L2(l2=self.reg_lambda)
#
# class StrongConvexSparseCategoricalCrossentropy(
# losses.SparseCategoricalCrossentropy,
# StrongConvexMixin
# ):
# """
# Strong Convex version of SparseCategoricalCrossentropy loss using l2 weight
# regularization.
# """
#
# def __init__(self,
# reg_lambda: float,
# C: float,
# radius_constant: float,
# from_logits: bool = True,
# label_smoothing: float = 0,
# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
# name: str = 'binarycrossentropy',
# dtype=tf.float32):
# """
# Args:
# reg_lambda: Weight regularization constant
# C: Penalty parameter C of the loss term
# radius_constant: constant defining the length of the radius
# reduction: reduction type to use. See super class
# label_smoothing: amount of smoothing to perform on labels
# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x)
# name: Name of the loss instance
# dtype: tf datatype to use for tensor conversions.
# """
# if reg_lambda <= 0:
# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
# if C <= 0:
# raise ValueError('c: {0}, should be >= 0'.format(C))
# if radius_constant <= 0:
# raise ValueError('radius_constant: {0}, should be >= 0'.format(
# radius_constant
# ))
#
# self.C = C
# self.dtype = dtype
# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
# super(StrongConvexHuber, self).__init__(reduction=reduction,
# name=name,
# from_logits=from_logits,
# label_smoothing=label_smoothing,
# )
# self.radius_constant = radius_constant
#
# def call(self, y_true, y_pred):
# """Compute loss
#
# Args:
# y_true: Ground truth values.
# y_pred: The predicted values.
#
# Returns:
# Loss values per sample.
# """
# loss = super()
# loss = loss * self.C
# return loss
#
# def radius(self):
# """See super class.
# """
# return self.radius_constant / self.reg_lambda
#
# def gamma(self):
# """See super class.
# """
# return self.reg_lambda
#
# def beta(self, class_weight):
# """See super class.
# """
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
# return self.C * max_class_weight + self.reg_lambda
#
# def lipchitz_constant(self, class_weight):
# """See super class.
# """
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
# return self.C * max_class_weight + self.reg_lambda * self.radius()
#
# def kernel_regularizer(self):
# """
# l2 loss using reg_lambda as the l2 term (as desired). Required for
# this loss function to be strongly convex.
# :return:
# """
# return L1L2(l2=self.reg_lambda)
#
#
# class StrongConvexCategoricalCrossentropy(
# losses.CategoricalCrossentropy,
# StrongConvexMixin
# ):
# """
# Strong Convex version of CategoricalCrossentropy loss using l2 weight
# regularization.
# """
#
# def __init__(self,
# reg_lambda: float,
# C: float,
# radius_constant: float,
# from_logits: bool = True,
# label_smoothing: float = 0,
# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
# name: str = 'binarycrossentropy',
# dtype=tf.float32):
# """
# Args:
# reg_lambda: Weight regularization constant
# C: Penalty parameter C of the loss term
# radius_constant: constant defining the length of the radius
# reduction: reduction type to use. See super class
# label_smoothing: amount of smoothing to perform on labels
# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x)
# name: Name of the loss instance
# dtype: tf datatype to use for tensor conversions.
# """
# if reg_lambda <= 0:
# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
# if C <= 0:
# raise ValueError('c: {0}, should be >= 0'.format(C))
# if radius_constant <= 0:
# raise ValueError('radius_constant: {0}, should be >= 0'.format(
# radius_constant
# ))
#
# self.C = C
# self.dtype = dtype
# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
# super(StrongConvexHuber, self).__init__(reduction=reduction,
# name=name,
# from_logits=from_logits,
# label_smoothing=label_smoothing,
# )
# self.radius_constant = radius_constant
#
# def call(self, y_true, y_pred):
# """Compute loss
#
# Args:
# y_true: Ground truth values.
# y_pred: The predicted values.
#
# Returns:
# Loss values per sample.
# """
# loss = super()
# loss = loss * self.C
# return loss
#
# def radius(self):
# """See super class.
# """
# return self.radius_constant / self.reg_lambda
#
# def gamma(self):
# """See super class.
# """
# return self.reg_lambda
#
# def beta(self, class_weight):
# """See super class.
# """
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
# return self.C * max_class_weight + self.reg_lambda
#
# def lipchitz_constant(self, class_weight):
# """See super class.
# """
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
# return self.C * max_class_weight + self.reg_lambda * self.radius()
#
# def kernel_regularizer(self):
# """
# l2 loss using reg_lambda as the l2 term (as desired). Required for
# this loss function to be strongly convex.
# :return:
# """
# return L1L2(l2=self.reg_lambda)

View file

@ -1,3 +1,325 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit testing for loss.py"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf
from tensorflow.python.platform import test
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.keras.optimizer_v2 import adagrad
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras import losses
from tensorflow.python.framework import test_util
from privacy.bolton import model
from privacy.bolton.loss import StrongConvexBinaryCrossentropy
from privacy.bolton.loss import StrongConvexHuber
from privacy.bolton.loss import StrongConvexMixin
from absl.testing import parameterized
from absl.testing import absltest
from tensorflow.python.keras.regularizers import L1L2
class StrongConvexTests(keras_parameterized.TestCase):
@parameterized.named_parameters([
{'testcase_name': 'beta not implemented',
'fn': 'beta',
'args': [1]},
{'testcase_name': 'gamma not implemented',
'fn': 'gamma',
'args': []},
{'testcase_name': 'lipchitz not implemented',
'fn': 'lipchitz_constant',
'args': [1]},
{'testcase_name': 'radius not implemented',
'fn': 'radius',
'args': []},
])
def test_not_implemented(self, fn, args):
with self.assertRaises(NotImplementedError):
loss = StrongConvexMixin()
getattr(loss, fn, None)(*args)
@parameterized.named_parameters([
{'testcase_name': 'radius not implemented',
'fn': 'kernel_regularizer',
'args': []},
])
def test_return_none(self, fn, args):
loss = StrongConvexMixin()
ret = getattr(loss, fn, None)(*args)
self.assertEqual(ret, None)
class BinaryCrossesntropyTests(keras_parameterized.TestCase):
"""tests for BinaryCrossesntropy StrongConvex loss"""
@parameterized.named_parameters([
{'testcase_name': 'normal',
'reg_lambda': 1,
'c': 1,
'radius_constant': 1
},
])
def test_init_params(self, reg_lambda, c, radius_constant):
# test valid domains for each variable
loss = StrongConvexBinaryCrossentropy(reg_lambda, c, radius_constant)
self.assertIsInstance(loss, StrongConvexBinaryCrossentropy)
@parameterized.named_parameters([
{'testcase_name': 'negative c',
'reg_lambda': 1,
'c': -1,
'radius_constant': 1
},
{'testcase_name': 'negative radius',
'reg_lambda': 1,
'c': 1,
'radius_constant': -1
},
{'testcase_name': 'negative lambda',
'reg_lambda': -1,
'c': 1,
'radius_constant': 1
},
])
def test_bad_init_params(self, reg_lambda, c, radius_constant):
# test valid domains for each variable
with self.assertRaises(ValueError):
loss = StrongConvexBinaryCrossentropy(reg_lambda, c, radius_constant)
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
# [] for compatibility with tensorflow loss calculation
{'testcase_name': 'both positive',
'logits': [10000],
'y_true': [1],
'result': 0,
},
{'testcase_name': 'positive gradient negative logits',
'logits': [-10000],
'y_true': [1],
'result': 10000,
},
{'testcase_name': 'positivee gradient positive logits',
'logits': [10000],
'y_true': [0],
'result': 10000,
},
{'testcase_name': 'both negative',
'logits': [-10000],
'y_true': [0],
'result': 0
},
])
def test_calculation(self, logits, y_true, result):
logits = tf.Variable(logits, False, dtype=tf.float32)
y_true = tf.Variable(y_true, False, dtype=tf.float32)
loss = StrongConvexBinaryCrossentropy(0.00001, 1, 1)
loss = loss(y_true, logits)
self.assertEqual(loss.numpy(), result)
@parameterized.named_parameters([
{'testcase_name': 'beta',
'init_args': [1, 1, 1],
'fn': 'beta',
'args': [1],
'result': tf.constant(2, dtype=tf.float32)
},
{'testcase_name': 'gamma',
'fn': 'gamma',
'init_args': [1, 1, 1],
'args': [],
'result': tf.constant(1, dtype=tf.float32),
},
{'testcase_name': 'lipchitz constant',
'fn': 'lipchitz_constant',
'init_args': [1, 1, 1],
'args': [1],
'result': tf.constant(2, dtype=tf.float32),
},
{'testcase_name': 'kernel regularizer',
'fn': 'kernel_regularizer',
'init_args': [1, 1, 1],
'args': [],
'result': L1L2(l2=1),
},
])
def test_fns(self, init_args, fn, args, result):
loss = StrongConvexBinaryCrossentropy(*init_args)
expected = getattr(loss, fn, lambda: 'fn not found')(*args)
if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor
expected = expected.numpy()
result = result.numpy()
if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer
expected = expected.l2
result = result.l2
self.assertEqual(expected, result)
class HuberTests(keras_parameterized.TestCase):
"""tests for BinaryCrossesntropy StrongConvex loss"""
@parameterized.named_parameters([
{'testcase_name': 'normal',
'reg_lambda': 1,
'c': 1,
'radius_constant': 1,
'delta': 1,
},
])
def test_init_params(self, reg_lambda, c, radius_constant, delta):
# test valid domains for each variable
loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta)
self.assertIsInstance(loss, StrongConvexHuber)
@parameterized.named_parameters([
{'testcase_name': 'negative c',
'reg_lambda': 1,
'c': -1,
'radius_constant': 1,
'delta': 1
},
{'testcase_name': 'negative radius',
'reg_lambda': 1,
'c': 1,
'radius_constant': -1,
'delta': 1
},
{'testcase_name': 'negative lambda',
'reg_lambda': -1,
'c': 1,
'radius_constant': 1,
'delta': 1
},
{'testcase_name': 'negative delta',
'reg_lambda': -1,
'c': 1,
'radius_constant': 1,
'delta': -1
},
])
def test_bad_init_params(self, reg_lambda, c, radius_constant, delta):
# test valid domains for each variable
with self.assertRaises(ValueError):
loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta)
# test the bounds and test varied delta's
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': 'delta=1,y_true=1 z>1+h decision boundary',
'logits': 2.1,
'y_true': 1,
'delta': 1,
'result': 0,
},
{'testcase_name': 'delta=1,y_true=1 z<1+h decision boundary',
'logits': 1.9,
'y_true': 1,
'delta': 1,
'result': 0.01*0.25,
},
{'testcase_name': 'delta=1,y_true=1 1-z< h decision boundary',
'logits': 0.1,
'y_true': 1,
'delta': 1,
'result': 1.9**2 * 0.25,
},
{'testcase_name': 'delta=1,y_true=1 z < 1-h decision boundary',
'logits': -0.1,
'y_true': 1,
'delta': 1,
'result': 1.1,
},
{'testcase_name': 'delta=2,y_true=1 z>1+h decision boundary',
'logits': 3.1,
'y_true': 1,
'delta': 2,
'result': 0,
},
{'testcase_name': 'delta=2,y_true=1 z<1+h decision boundary',
'logits': 2.9,
'y_true': 1,
'delta': 2,
'result': 0.01*0.125,
},
{'testcase_name': 'delta=2,y_true=1 1-z < h decision boundary',
'logits': 1.1,
'y_true': 1,
'delta': 2,
'result': 1.9**2 * 0.125,
},
{'testcase_name': 'delta=2,y_true=1 z < 1-h decision boundary',
'logits': -1.1,
'y_true': 1,
'delta': 2,
'result': 2.1,
},
{'testcase_name': 'delta=1,y_true=-1 z>1+h decision boundary',
'logits': -2.1,
'y_true': -1,
'delta': 1,
'result': 0,
},
])
def test_calculation(self, logits, y_true, delta, result):
logits = tf.Variable(logits, False, dtype=tf.float32)
y_true = tf.Variable(y_true, False, dtype=tf.float32)
loss = StrongConvexHuber(0.00001, 1, 1, delta)
loss = loss(y_true, logits)
self.assertAllClose(loss.numpy(), result)
@parameterized.named_parameters([
{'testcase_name': 'beta',
'init_args': [1, 1, 1, 1],
'fn': 'beta',
'args': [1],
'result': tf.Variable(1.5, dtype=tf.float32)
},
{'testcase_name': 'gamma',
'fn': 'gamma',
'init_args': [1, 1, 1, 1],
'args': [],
'result': tf.Variable(1, dtype=tf.float32),
},
{'testcase_name': 'lipchitz constant',
'fn': 'lipchitz_constant',
'init_args': [1, 1, 1, 1],
'args': [1],
'result': tf.Variable(2, dtype=tf.float32),
},
{'testcase_name': 'kernel regularizer',
'fn': 'kernel_regularizer',
'init_args': [1, 1, 1, 1],
'args': [],
'result': L1L2(l2=1),
},
])
def test_fns(self, init_args, fn, args, result):
loss = StrongConvexHuber(*init_args)
expected = getattr(loss, fn, lambda: 'fn not found')(*args)
if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor
expected = expected.numpy()
result = result.numpy()
if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer
expected = expected.l2
result = result.l2
self.assertEqual(expected, result)
if __name__ == '__main__':
tf.test.main()

View file

@ -19,11 +19,12 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras.models import Model from tensorflow.python.keras.models import Model
from tensorflow.python.keras import optimizers from tensorflow.python.keras import optimizers
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.framework import ops as _ops from tensorflow.python.framework import ops as _ops
from privacy.bolton.loss import StrongConvexLoss from privacy.bolton.loss import StrongConvexMixin
from privacy.bolton.optimizer import Private from privacy.bolton.optimizer import Private
_accepted_distributions = ['laplace']
class Bolton(Model): class Bolton(Model):
""" """
@ -33,12 +34,16 @@ class Bolton(Model):
2. Projects weights to R after each batch 2. Projects weights to R after each batch
3. Limits learning rate 3. Limits learning rate
4. Use a strongly convex loss function (see compile) 4. Use a strongly convex loss function (see compile)
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et. al.
""" """
def __init__(self, def __init__(self,
n_classes, n_classes,
epsilon, epsilon,
noise_distribution='laplace', noise_distribution='laplace',
weights_initializer=tf.initializers.GlorotUniform(),
seed=1, seed=1,
dtype=tf.float32 dtype=tf.float32
): ):
@ -59,6 +64,7 @@ class Bolton(Model):
2. Projects weights to R after each batch 2. Projects weights to R after each batch
3. Limits learning rate 3. Limits learning rate
""" """
def on_train_batch_end(self, batch, logs=None): def on_train_batch_end(self, batch, logs=None):
loss = self.model.loss loss = self.model.loss
self.model.optimizer.limit_learning_rate( self.model.optimizer.limit_learning_rate(
@ -72,13 +78,17 @@ class Bolton(Model):
loss = self.model.loss loss = self.model.loss
self.model._project_weights_to_r(loss.radius(), True) self.model._project_weights_to_r(loss.radius(), True)
if epsilon <= 0:
raise ValueError('Detected epsilon: {0}. '
'Valid range is 0 < epsilon <inf'.format(epsilon))
if noise_distribution not in _accepted_distributions:
raise ValueError('Detected noise distribution: {0} not one of: {1} valid'
'distributions'.format(noise_distribution,
_accepted_distributions))
super(Bolton, self).__init__(name='bolton', dynamic=False) super(Bolton, self).__init__(name='bolton', dynamic=False)
self.n_classes = n_classes self.n_classes = n_classes
self.output_layer = tf.keras.layers.Dense(
self.n_classes,
kernel_regularizer=tf.keras.regularizers.l2(),
kernel_initializer=weights_initializer,
)
# if we do regularization here, we require the user to re-instantiate # if we do regularization here, we require the user to re-instantiate
# the model each time they want to # the model each time they want to
# change lambda, unless we standardize modifying it later at .compile # change lambda, unless we standardize modifying it later at .compile
@ -87,6 +97,7 @@ class Bolton(Model):
self.epsilon = epsilon self.epsilon = epsilon
self.seed = seed self.seed = seed
self.__in_fit = False self.__in_fit = False
self._layers_instantiated = False
self._callback = MyCustomCallback() self._callback = MyCustomCallback()
self._dtype = dtype self._dtype = dtype
@ -114,15 +125,24 @@ class Bolton(Model):
"""See super class. Default optimizer used in Bolton method is SGD. """See super class. Default optimizer used in Bolton method is SGD.
""" """
if not isinstance(loss, StrongConvexLoss): for key, val in StrongConvexMixin.__dict__.items():
raise ValueError("Loss must be subclassed from StrongConvexLoss") if callable(val) and getattr(loss, key, None) is None:
self.output_layer.kernel_regularizer.l2 = loss.reg_lambda() raise ValueError("Please ensure you are passing a valid StrongConvex "
"loss that has all the required methods "
"implemented. "
"Required method: {0} not found".format(key))
if not self._layers_instantiated: # compile may be called multiple times
kernel_intiializer = kwargs.get('kernel_initializer',
tf.initializers.GlorotUniform)
self.output_layer = tf.keras.layers.Dense(
self.n_classes,
kernel_regularizer=loss.kernel_regularizer(),
kernel_initializer=kernel_intiializer(),
)
self._layers_instantiated = True
self.output_layer.kernel_regularizer.l2 = loss.reg_lambda
if not isinstance(optimizer, Private): if not isinstance(optimizer, Private):
optimizer = optimizers.get(optimizer) optimizer = optimizers.get(optimizer)
if isinstance(self.optimizer, trackable.Trackable):
self._track_trackable(
self.optimizer, name='optimizer', overwrite=True
)
optimizer = Private(optimizer) optimizer = Private(optimizer)
super(Bolton, self).compile(optimizer, super(Bolton, self).compile(optimizer,
@ -149,21 +169,20 @@ class Bolton(Model):
Returns: Returns:
""" """
data_size = None
if n_samples is not None: if n_samples is not None:
data_size = n_samples data_size = n_samples
elif hasattr(x, 'shape'): elif hasattr(x, 'shape'):
data_size = x.shape[0] data_size = x.shape[0]
elif hasattr(x, "__len__"): elif hasattr(x, "__len__"):
data_size = len(x) data_size = len(x)
else: elif data_size is None:
if n_samples is None: if n_samples is None:
raise ValueError("Unable to detect the number of training " raise ValueError("Unable to detect the number of training "
"samples and n_smaples was None. " "samples and n_smaples was None. "
"either pass a dataset with a .shape or " "either pass a dataset with a .shape or "
"__len__ attribute or explicitly pass the " "__len__ attribute or explicitly pass the "
"number of samples as n_smaples.") "number of samples as n_smaples.")
data_size = n_samples
for layer in self._layers: for layer in self._layers:
layer.kernel = layer.kernel + self._get_noise( layer.kernel = layer.kernel + self._get_noise(
self.noise_distribution, self.noise_distribution,
@ -294,8 +313,8 @@ class Bolton(Model):
Calculates class weighting to be used in training. Can be on Calculates class weighting to be used in training. Can be on
Args: Args:
class_weights: str specifying type, array giving weights, or None. class_weights: str specifying type, array giving weights, or None.
class_counts: If class_weights is not None, then the number of class_counts: If class_weights is not None, then an array of
samples for each class the number of samples for each class
num_classes: If class_weights is not None, then the number of num_classes: If class_weights is not None, then the number of
classes. classes.
Returns: class_weights as 1D tensor, to be passed to model's fit method. Returns: class_weights as 1D tensor, to be passed to model's fit method.
@ -313,10 +332,16 @@ class Bolton(Model):
"or pass an array".format(class_weights, "or pass an array".format(class_weights,
class_keys)) class_keys))
if class_counts is None: if class_counts is None:
raise ValueError("Class counts must be provided if using" raise ValueError("Class counts must be provided if using "
"class_weights=%s" % class_weights) "class_weights=%s" % class_weights)
class_counts_shape = tf.Variable(class_counts,
trainable=False,
dtype=self._dtype).shape
if len(class_counts_shape) != 1:
raise ValueError('class counts must be a 1D array.'
'Detected: {0}'.format(class_counts_shape))
if num_classes is None: if num_classes is None:
raise ValueError("Class counts must be provided if using" raise ValueError("num_classes must be provided if using "
"class_weights=%s" % class_weights) "class_weights=%s" % class_weights)
elif class_weights is not None: elif class_weights is not None:
if num_classes is None: if num_classes is None:
@ -327,10 +352,13 @@ class Bolton(Model):
class_weights = 1 class_weights = 1
elif is_string and class_weights == 'balanced': elif is_string and class_weights == 'balanced':
num_samples = sum(class_counts) num_samples = sum(class_counts)
class_weights = tf.Variable( weighted_counts = tf.dtypes.cast(tf.math.multiply(num_classes,
num_samples / (num_classes * class_counts), class_counts,
dtype=self._dtype ),
) self._dtype
)
class_weights = tf.Variable(num_samples, dtype=self._dtype) / \
tf.Variable(weighted_counts, dtype=self._dtype)
else: else:
class_weights = _ops.convert_to_tensor_v2(class_weights) class_weights = _ops.convert_to_tensor_v2(class_weights)
if len(class_weights.shape) != 1: if len(class_weights.shape) != 1:
@ -376,7 +404,7 @@ class Bolton(Model):
distribution = distribution.lower() distribution = distribution.lower()
input_dim = self._layers[0].kernel.numpy().shape[0] input_dim = self._layers[0].kernel.numpy().shape[0]
loss = self.loss loss = self.loss
if distribution == 'laplace': if distribution == _accepted_distributions[0]: # laplace
per_class_epsilon = self.epsilon / (self.n_classes) per_class_epsilon = self.epsilon / (self.n_classes)
l2_sensitivity = (2 * l2_sensitivity = (2 *
loss.lipchitz_constant(self.class_weight)) / \ loss.lipchitz_constant(self.class_weight)) / \
@ -396,7 +424,8 @@ class Bolton(Model):
alpha, alpha,
beta=1 / beta, beta=1 / beta,
seed=1, seed=1,
dtype=self._dtype) dtype=self._dtype
)
return unit_vector * gamma return unit_vector * gamma
raise NotImplementedError("distribution: {0} is not " raise NotImplementedError('Noise distribution: {0} is not '
"currently supported".format(distribution)) 'a valid distribution'.format(distribution))

View file

@ -1,3 +1,494 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit testing for model.py"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf
from tensorflow.python.platform import test
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
from tensorflow.python.keras import losses
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import test_util
from privacy.bolton import model
from privacy.bolton.loss import StrongConvexMixin
from absl.testing import parameterized
from absl.testing import absltest
from tensorflow.python.keras.regularizers import L1L2
class TestLoss(losses.Loss):
"""Test loss function for testing Bolton model"""
def __init__(self, reg_lambda, C, radius_constant, name='test'):
super(TestLoss, self).__init__(name=name)
self.reg_lambda = reg_lambda
self.C = C
self.radius_constant = radius_constant
def radius(self):
"""Radius of R-Ball (value to normalize weights to after each batch)
Returns: radius
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def gamma(self):
""" Gamma strongly convex
Returns: gamma
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def beta(self, class_weight):
"""Beta smoothess
Args:
class_weight: the class weights used.
Returns: Beta
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def lipchitz_constant(self, class_weight):
""" L lipchitz continuous
Args:
class_weight: class weights used
Returns: L
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def call(self, val0, val1):
"""Loss function that is minimized at the mean of the input points."""
return 0.5 * tf.reduce_sum(tf.math.squared_difference(val0, val1), axis=1)
def max_class_weight(self, class_weight):
if class_weight is None:
return 1
def kernel_regularizer(self):
return L1L2(l2=self.reg_lambda)
class TestOptimizer(OptimizerV2):
"""Test optimizer used for testing Bolton model"""
def __init__(self):
super(TestOptimizer, self).__init__('test')
def compute_gradients(self):
return 0
def get_config(self):
return {}
def _create_slots(self, var):
pass
def _resource_apply_dense(self, grad, handle):
return grad
def _resource_apply_sparse(self, grad, handle, indices):
return grad
class InitTests(keras_parameterized.TestCase):
"""tests for keras model initialization"""
@parameterized.named_parameters([
{'testcase_name': 'normal',
'n_classes': 1,
'epsilon': 1,
'noise_distribution': 'laplace',
'seed': 1
},
{'testcase_name': 'extreme range',
'n_classes': 5,
'epsilon': 0.1,
'noise_distribution': 'laplace',
'seed': 10
},
{'testcase_name': 'extreme range2',
'n_classes': 50,
'epsilon': 10,
'noise_distribution': 'laplace',
'seed': 100
},
])
def test_init_params(
self, n_classes, epsilon, noise_distribution, seed):
# test valid domains for each variable
clf = model.Bolton(n_classes,
epsilon,
noise_distribution,
seed
)
self.assertIsInstance(clf, model.Bolton)
@parameterized.named_parameters([
{'testcase_name': 'invalid noise',
'n_classes': 1,
'epsilon': 1,
'noise_distribution': 'not_valid',
'weights_initializer': tf.initializers.GlorotUniform(),
},
{'testcase_name': 'invalid epsilon',
'n_classes': 1,
'epsilon': -1,
'noise_distribution': 'laplace',
'weights_initializer': tf.initializers.GlorotUniform(),
},
])
def test_bad_init_params(
self, n_classes, epsilon, noise_distribution, weights_initializer):
# test invalid domains for each variable, especially noise
seed = 1
with self.assertRaises(ValueError):
clf = model.Bolton(n_classes,
epsilon,
noise_distribution,
weights_initializer,
seed
)
@parameterized.named_parameters([
{'testcase_name': 'string compile',
'n_classes': 1,
'loss': TestLoss(1, 1, 1),
'optimizer': 'adam',
'weights_initializer': tf.initializers.GlorotUniform(),
},
{'testcase_name': 'test compile',
'n_classes': 100,
'loss': TestLoss(1, 1, 1),
'optimizer': TestOptimizer(),
'weights_initializer': tf.initializers.GlorotUniform(),
},
{'testcase_name': 'invalid weights initializer',
'n_classes': 1,
'loss': TestLoss(1, 1, 1),
'optimizer': TestOptimizer(),
'weights_initializer': 'not_valid',
},
])
def test_compile(self, n_classes, loss, optimizer, weights_initializer):
# test compilation of valid tf.optimizer and tf.loss
epsilon = 1
noise_distribution = 'laplace'
with self.cached_session():
clf = model.Bolton(n_classes,
epsilon,
noise_distribution,
weights_initializer
)
clf.compile(optimizer, loss)
self.assertEqual(clf.loss, loss)
@parameterized.named_parameters([
{'testcase_name': 'Not strong loss',
'n_classes': 1,
'loss': losses.BinaryCrossentropy(),
'optimizer': 'adam',
},
{'testcase_name': 'Not valid optimizer',
'n_classes': 1,
'loss': TestLoss(1, 1, 1),
'optimizer': 'ada',
}
])
def test_bad_compile(self, n_classes, loss, optimizer):
# test compilaton of invalid tf.optimizer and non instantiated loss.
epsilon = 1
noise_distribution = 'laplace'
weights_initializer = tf.initializers.GlorotUniform()
with self.cached_session():
with self.assertRaises((ValueError, AttributeError)):
clf = model.Bolton(n_classes,
epsilon,
noise_distribution,
weights_initializer
)
clf.compile(optimizer, loss)
def _cat_dataset(n_samples, input_dim, n_classes, t='train', generator=False):
"""
Creates a categorically encoded dataset (y is categorical).
returns the specified dataset either as a static array or as a generator.
Will have evenly split samples across each output class.
Each output class will be a different point in the input space.
Args:
n_samples: number of rows
input_dim: input dimensionality
n_classes: output dimensionality
t: one of 'train', 'val', 'test'
generator: False for array, True for generator
Returns:
X as (n_samples, input_dim), Y as (n_samples, n_classes)
"""
x_stack = []
y_stack = []
for i_class in range(n_classes):
x_stack.append(
tf.constant(1*i_class, tf.float32, (n_samples, input_dim))
)
y_stack.append(
tf.constant(i_class, tf.float32, (n_samples, n_classes))
)
x_set, y_set = tf.stack(x_stack), tf.stack(y_stack)
if generator:
dataset = tf.data.Dataset.from_tensor_slices(
(x_set, y_set)
)
return dataset
return x_set, y_set
def _do_fit(n_samples,
input_dim,
n_classes,
epsilon,
generator,
batch_size,
reset_n_samples,
optimizer,
loss,
callbacks,
distribution='laplace'):
clf = model.Bolton(n_classes,
epsilon,
distribution
)
clf.compile(optimizer, loss)
if generator:
x = _cat_dataset(
n_samples,
input_dim,
n_classes,
generator=generator
)
y = None
# x = x.batch(batch_size)
x = x.shuffle(n_samples//2)
batch_size = None
else:
x, y = _cat_dataset(n_samples, input_dim, n_classes, generator=generator)
if reset_n_samples:
n_samples = None
if callbacks is not None:
callbacks = [callbacks]
clf.fit(x,
y,
batch_size=batch_size,
n_samples=n_samples,
callbacks=callbacks
)
return clf
class TestCallback(tf.keras.callbacks.Callback):
pass
class FitTests(keras_parameterized.TestCase):
"""Test cases for keras model fitting"""
# @test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': 'iterator fit',
'generator': False,
'reset_n_samples': True,
'callbacks': None
},
{'testcase_name': 'iterator fit no samples',
'generator': False,
'reset_n_samples': True,
'callbacks': None
},
{'testcase_name': 'generator fit',
'generator': True,
'reset_n_samples': False,
'callbacks': None
},
{'testcase_name': 'with callbacks',
'generator': True,
'reset_n_samples': False,
'callbacks': TestCallback()
},
])
def test_fit(self, generator, reset_n_samples, callbacks):
loss = TestLoss(1, 1, 1)
optimizer = TestOptimizer()
n_classes = 2
input_dim = 5
epsilon = 1
batch_size = 1
n_samples = 10
clf = _do_fit(n_samples, input_dim, n_classes, epsilon, generator, batch_size,
reset_n_samples, optimizer, loss, callbacks)
self.assertEqual(hasattr(clf, '_layers'), True)
@parameterized.named_parameters([
{'testcase_name': 'generator fit',
'generator': True,
'reset_n_samples': False,
'callbacks': None
},
])
def test_fit_gen(self, generator, reset_n_samples, callbacks):
loss = TestLoss(1, 1, 1)
optimizer = TestOptimizer()
n_classes = 2
input_dim = 5
epsilon = 1
batch_size = 1
n_samples = 10
clf = model.Bolton(n_classes,
epsilon
)
clf.compile(optimizer, loss)
x = _cat_dataset(
n_samples,
input_dim,
n_classes,
generator=generator
)
x = x.batch(batch_size)
x = x.shuffle(n_samples // 2)
clf.fit_generator(x, n_samples=n_samples)
self.assertEqual(hasattr(clf, '_layers'), True)
@parameterized.named_parameters([
{'testcase_name': 'iterator no n_samples',
'generator': True,
'reset_n_samples': True,
'distribution': 'laplace'
},
{'testcase_name': 'invalid distribution',
'generator': True,
'reset_n_samples': True,
'distribution': 'not_valid'
},
])
def test_bad_fit(self, generator, reset_n_samples, distribution):
with self.assertRaises(ValueError):
loss = TestLoss(1, 1, 1)
optimizer = TestOptimizer()
n_classes = 2
input_dim = 5
epsilon = 1
batch_size = 1
n_samples = 10
_do_fit(n_samples, input_dim, n_classes, epsilon, generator, batch_size,
reset_n_samples, optimizer, loss, None, distribution)
@parameterized.named_parameters([
{'testcase_name': 'None class_weights',
'class_weights': None,
'class_counts': None,
'num_classes': None,
'result': 1},
{'testcase_name': 'class weights array',
'class_weights': [1, 1],
'class_counts': [1, 1],
'num_classes': 2,
'result': [1, 1]},
{'testcase_name': 'class weights balanced',
'class_weights': 'balanced',
'class_counts': [1, 1],
'num_classes': 2,
'result': [1, 1]},
])
def test_class_calculate(self,
class_weights,
class_counts,
num_classes,
result
):
clf = model.Bolton(1, 1)
expected = clf.calculate_class_weights(class_weights,
class_counts,
num_classes
)
if hasattr(expected, 'numpy'):
expected = expected.numpy()
self.assertAllEqual(
expected,
result
)
@parameterized.named_parameters([
{'testcase_name': 'class weight not valid str',
'class_weights': 'not_valid',
'class_counts': 1,
'num_classes': 1,
'err_msg': "Detected string class_weights with value: not_valid"},
{'testcase_name': 'no class counts',
'class_weights': 'balanced',
'class_counts': None,
'num_classes': 1,
'err_msg':
"Class counts must be provided if using class_weights=balanced"},
{'testcase_name': 'no num classes',
'class_weights': 'balanced',
'class_counts': [1],
'num_classes': None,
'err_msg':
'num_classes must be provided if using class_weights=balanced'},
{'testcase_name': 'class counts not array',
'class_weights': 'balanced',
'class_counts': 1,
'num_classes': None,
'err_msg': 'class counts must be a 1D array.'},
{'testcase_name': 'class counts array, no num classes',
'class_weights': [1],
'class_counts': None,
'num_classes': None,
'err_msg': "You must pass a value for num_classes if"
"creating an array of class_weights"},
{'testcase_name': 'class counts array, improper shape',
'class_weights': [[1], [1]],
'class_counts': None,
'num_classes': 2,
'err_msg': "Detected class_weights shape"},
{'testcase_name': 'class counts array, wrong number classes',
'class_weights': [1, 1, 1],
'class_counts': None,
'num_classes': 2,
'err_msg': "Detected array length:"},
])
def test_class_errors(self,
class_weights,
class_counts,
num_classes,
err_msg):
clf = model.Bolton(1, 1)
with self.assertRaisesRegexp(ValueError, err_msg):
expected = clf.calculate_class_weights(class_weights,
class_counts,
num_classes
)
if __name__ == '__main__':
tf.test.main()

View file

@ -29,6 +29,10 @@ class Private(optimizer_v2.OptimizerV2):
as the visible optimizer to the tf model. No matter the optimizer as the visible optimizer to the tf model. No matter the optimizer
passed, "Private" enables the bolton model to control the learning rate passed, "Private" enables the bolton model to control the learning rate
based on the strongly convex loss. based on the strongly convex loss.
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et. al.
""" """
def __init__(self, def __init__(self,
optimizer: optimizer_v2.OptimizerV2, optimizer: optimizer_v2.OptimizerV2,
@ -76,13 +80,10 @@ class Private(optimizer_v2.OptimizerV2):
else: else:
self.learning_rate = numerator / (gamma * t) self.learning_rate = numerator / (gamma * t)
def from_config(self, config, custom_objects=None): def from_config(self, *args, **kwargs):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer.
""" """
return self._internal_optimizer.from_config( return self._internal_optimizer.from_config(*args, **kwargs)
config,
custom_objects=custom_objects
)
def __getattr__(self, name): def __getattr__(self, name):
"""return _internal_optimizer off self instance, and everything else """return _internal_optimizer off self instance, and everything else
@ -116,58 +117,37 @@ class Private(optimizer_v2.OptimizerV2):
else: else:
setattr(self._internal_optimizer, key, value) setattr(self._internal_optimizer, key, value)
def _resource_apply_dense(self, grad, handle): def _resource_apply_dense(self, *args, **kwargs):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer.
""" """
return self._internal_optimizer._resource_apply_dense(grad, handle) return self._internal_optimizer._resource_apply_dense(*args, **kwargs)
def _resource_apply_sparse(self, grad, handle, indices): def _resource_apply_sparse(self, *args, **kwargs):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer.
""" """
return self._internal_optimizer._resource_apply_sparse( return self._internal_optimizer._resource_apply_sparse(*args, **kwargs)
grad,
handle,
indices
)
def get_updates(self, loss, params): def get_updates(self, loss, params):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer.
""" """
return self._internal_optimizer.get_updates(loss, params) return self._internal_optimizer.get_updates(loss, params)
def apply_gradients(self, grads_and_vars, name: str = None): def apply_gradients(self, *args, **kwargs):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer.
""" """
return self._internal_optimizer.apply_gradients( return self._internal_optimizer.apply_gradients(*args, **kwargs)
grads_and_vars,
name=name
)
def minimize(self, def minimize(self, *args, **kwargs):
loss,
var_list,
grad_loss: bool = None,
name: str = None
):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer.
""" """
return self._internal_optimizer.minimize( return self._internal_optimizer.minimize(*args, **kwargs)
loss,
var_list,
grad_loss,
name
)
def _compute_gradients(self, loss, var_list, grad_loss=None): def _compute_gradients(self, *args, **kwargs):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer.
""" """
return self._internal_optimizer._compute_gradients( return self._internal_optimizer._compute_gradients(*args, **kwargs)
loss,
var_list,
grad_loss=grad_loss
)
def get_gradients(self, loss, params): def get_gradients(self, *args, **kwargs):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer.
""" """
return self._internal_optimizer.get_gradients(loss, params) return self._internal_optimizer.get_gradients(*args, **kwargs)

View file

@ -1,9 +1,182 @@
# Copyright 2018, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit testing for optimizer.py"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from privacy.bolton import model from privacy.bolton import model
from privacy.bolton import optimizer as opt
from absl.testing import parameterized
from absl.testing import absltest
class TestOptimizer(OptimizerV2):
"""Optimizer used for testing the Private optimizer"""
def __init__(self):
super(TestOptimizer, self).__init__('test')
self.not_private = 'test'
self.iterations = tf.Variable(1, dtype=tf.float32)
self._iterations = tf.Variable(1, dtype=tf.float32)
def _compute_gradients(self, loss, var_list, grad_loss=None):
return 'test'
def get_config(self):
return 'test'
def from_config(cls, config, custom_objects=None):
return 'test'
def _create_slots(self):
return 'test'
def _resource_apply_dense(self, grad, handle):
return 'test'
def _resource_apply_sparse(self, grad, handle, indices):
return 'test'
def get_updates(self, loss, params):
return 'test'
def apply_gradients(self, grads_and_vars, name=None):
return 'test'
def minimize(self, loss, var_list, grad_loss=None, name=None):
return 'test'
def get_gradients(self, loss, params):
return 'test'
class PrivateTest(keras_parameterized.TestCase):
"""Private Optimizer tests"""
@parameterized.named_parameters([
{'testcase_name': 'branch True, beta',
'fn': 'limit_learning_rate',
'args': [True,
tf.Variable(2, dtype=tf.float32),
tf.Variable(1, dtype=tf.float32)],
'result': tf.Variable(0.5, dtype=tf.float32),
'test_attr': 'learning_rate'},
{'testcase_name': 'branch True, gamma',
'fn': 'limit_learning_rate',
'args': [True,
tf.Variable(1, dtype=tf.float32),
tf.Variable(1, dtype=tf.float32)],
'result': tf.Variable(1, dtype=tf.float32),
'test_attr': 'learning_rate'},
{'testcase_name': 'branch False, beta',
'fn': 'limit_learning_rate',
'args': [False,
tf.Variable(2, dtype=tf.float32),
tf.Variable(1, dtype=tf.float32)],
'result': tf.Variable(0.5, dtype=tf.float32),
'test_attr': 'learning_rate'},
{'testcase_name': 'branch False, gamma',
'fn': 'limit_learning_rate',
'args': [False,
tf.Variable(1, dtype=tf.float32),
tf.Variable(1, dtype=tf.float32)],
'result': tf.Variable(1, dtype=tf.float32),
'test_attr': 'learning_rate'},
{'testcase_name': 'getattr',
'fn': '__getattr__',
'args': ['dtype'],
'result': tf.float32,
'test_attr': None},
])
def test_fn(self, fn, args, result, test_attr):
private = opt.Private(TestOptimizer())
res = getattr(private, fn, None)(*args)
if test_attr is not None:
res = getattr(private, test_attr, None)
if hasattr(res, 'numpy') and hasattr(result, 'numpy'): # both tensors/not
res = res.numpy()
result = result.numpy()
self.assertEqual(res, result)
@parameterized.named_parameters([
{'testcase_name': 'fn: get_updates',
'fn': 'get_updates',
'args': [0, 0]},
{'testcase_name': 'fn: get_config',
'fn': 'get_config',
'args': []},
{'testcase_name': 'fn: from_config',
'fn': 'from_config',
'args': [0]},
{'testcase_name': 'fn: _resource_apply_dense',
'fn': '_resource_apply_dense',
'args': [1, 1]},
{'testcase_name': 'fn: _resource_apply_sparse',
'fn': '_resource_apply_sparse',
'args': [1, 1, 1]},
{'testcase_name': 'fn: apply_gradients',
'fn': 'apply_gradients',
'args': [1]},
{'testcase_name': 'fn: minimize',
'fn': 'minimize',
'args': [1, 1]},
{'testcase_name': 'fn: _compute_gradients',
'fn': '_compute_gradients',
'args': [1, 1]},
{'testcase_name': 'fn: get_gradients',
'fn': 'get_gradients',
'args': [1, 1]},
])
def test_rerouted_function(self, fn, args):
optimizer = TestOptimizer()
optimizer = opt.Private(optimizer)
self.assertEqual(
getattr(optimizer, fn, lambda: 'fn not found')(*args),
'test'
)
@parameterized.named_parameters([
{'testcase_name': 'fn: limit_learning_rate',
'fn': 'limit_learning_rate',
'args': [1, 1, 1]}
])
def test_not_reroute_fn(self, fn, args):
optimizer = TestOptimizer()
optimizer = opt.Private(optimizer)
self.assertNotEqual(getattr(optimizer, fn, lambda: 'test')(*args),
'test')
@parameterized.named_parameters([
{'testcase_name': 'attr: not_private',
'attr': 'not_private'}
])
def test_reroute_attr(self, attr):
internal_optimizer = TestOptimizer()
optimizer = opt.Private(internal_optimizer)
self.assertEqual(optimizer._internal_optimizer, internal_optimizer)
@parameterized.named_parameters([
{'testcase_name': 'attr: _internal_optimizer',
'attr': '_internal_optimizer'}
])
def test_not_reroute_attr(self, attr):
internal_optimizer = TestOptimizer()
optimizer = opt.Private(internal_optimizer)
self.assertEqual(optimizer._internal_optimizer, internal_optimizer)
if __name__ == '__main__':
test.main()