forked from 626_privacy/tensorflow_privacy
commit eab43e8294
9 changed files with 681 additions and 638 deletions
@@ -1,3 +1,17 @@
+# Copyright 2019, The TensorFlow Privacy Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Bolton Method for privacy."""
 import sys
 from distutils.version import LooseVersion
 import tensorflow as tf

@@ -12,4 +26,4 @@ else:
 from privacy.bolton.models import BoltonModel
 from privacy.bolton.optimizers import Bolton
 from privacy.bolton.losses import StrongConvexHuber
 from privacy.bolton.losses import StrongConvexBinaryCrossentropy
@@ -1,4 +1,4 @@
-# Copyright 2018, The TensorFlow Authors.
+# Copyright 2019, The TensorFlow Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -40,53 +40,47 @@ class StrongConvexMixin:
 """Radius, R, of the hypothesis space W.
 W is a convex set that forms the hypothesis space.

-Returns: R
+Returns:
+  R
 """
 raise NotImplementedError("Radius not implemented for StrongConvex Loss"
 "function: %s" % str(self.__class__.__name__))

 def gamma(self):
-""" Strongly convexity, gamma
-
-Returns: gamma
-
-"""
+"""Returns strongly convex parameter, gamma."""
 raise NotImplementedError("Gamma not implemented for StrongConvex Loss"
 "function: %s" % str(self.__class__.__name__))

 def beta(self, class_weight):
-"""Smoothness, beta
+"""Smoothness, beta.

 Args:
 class_weight: the class weights as scalar or 1d tensor, where its
 dimensionality is equal to the number of outputs.

-Returns: Beta
+Returns:
+  Beta
 """
 raise NotImplementedError("Beta not implemented for StrongConvex Loss"
 "function: %s" % str(self.__class__.__name__))

 def lipchitz_constant(self, class_weight):
-"""Lipchitz constant, L
+"""Lipchitz constant, L.

 Args:
 class_weight: class weights used

 Returns: L

 """
 raise NotImplementedError("lipchitz constant not implemented for "
 "StrongConvex Loss"
 "function: %s" % str(self.__class__.__name__))

 def kernel_regularizer(self):
-"""returns the kernel_regularizer to be used. Any subclass should override
-this method if they want a kernel_regularizer (if required for
-the loss function to be StronglyConvex
-
-:return: None or kernel_regularizer layer
+"""Returns the kernel_regularizer to be used.
+
+Any subclass should override this method if they want a kernel_regularizer
+(if required for the loss function to be StronglyConvex.
 """
 return None
@@ -97,16 +91,15 @@ class StrongConvexMixin:
 class_weight: class weights used
 dtype: the data type for tensor conversions.

-Returns: maximum class weighting as tensor scalar
+Returns:
+  maximum class weighting as tensor scalar
 """
 class_weight = _ops.convert_to_tensor_v2(class_weight, dtype)
 return tf.math.reduce_max(class_weight)


 class StrongConvexHuber(losses.Loss, StrongConvexMixin):
-"""Strong Convex version of Huber loss using l2 weight regularization.
-"""
+"""Strong Convex version of Huber loss using l2 weight regularization."""

 def __init__(self,
 reg_lambda: float,
@@ -153,7 +146,7 @@ class StrongConvexHuber(losses.Loss, StrongConvexMixin):
 )

 def call(self, y_true, y_pred):
-"""Compute loss
+"""Computes loss

 Args:
 y_true: Ground truth values. One hot encoded using -1 and 1.
@@ -162,7 +155,6 @@ class StrongConvexHuber(losses.Loss, StrongConvexMixin):
 Returns:
 Loss values per sample.
 """
-# return super(StrongConvexHuber, self).call(y_true, y_pred) * self._sample_weight
 h = self.delta
 z = y_pred * y_true
 one = tf.constant(1, dtype=self.dtype)
@@ -172,23 +164,18 @@ class StrongConvexHuber(losses.Loss, StrongConvexMixin):
 return _ops.convert_to_tensor_v2(0, dtype=self.dtype)
 elif tf.math.abs(one - z) <= h:
 return one / (four * h) * tf.math.pow(one + h - z, 2)
-elif z < one - h:
-return one - z
-
-raise ValueError('')  # shouldn't be possible to get here.
+return one - z  # elif: z < one - h

 def radius(self):
-"""See super class.
-"""
+"""See super class."""
 return self.radius_constant / self.reg_lambda

 def gamma(self):
-"""See super class.
-"""
+"""See super class."""
 return self.reg_lambda

 def beta(self, class_weight):
-"""See super class.
-"""
+"""See super class."""
 max_class_weight = self.max_class_weight(class_weight, self.dtype)
 delta = _ops.convert_to_tensor_v2(self.delta,
 dtype=self.dtype
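The call method shown in the hunks above evaluates a smoothed hinge (Huber) loss piecewise in the margin z = y_pred * y_true, with smoothing width h = delta. A minimal plain-Python sketch of that branch logic (illustrative only, not part of the commit; the z >= 1 + h condition for the zero branch is implied by the surrounding code rather than shown in the hunk):

def huber_margin_loss(z, h):
  """Strong-convex Huber loss as a function of the margin z = y_pred * y_true."""
  if z >= 1 + h:           # confidently correct prediction: no loss
    return 0.0
  if abs(1 - z) <= h:      # quadratic smoothing region around the hinge
    return (1 + h - z) ** 2 / (4 * h)
  return 1 - z             # z < 1 - h: linear (hinge-like) region

# huber_margin_loss(0.5, 0.7) == 1.2**2 / 2.8 ~= 0.514; huber_margin_loss(2.0, 0.7) == 0.0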
@@ -198,8 +185,7 @@ class StrongConvexHuber(losses.Loss, StrongConvexMixin):
 self.reg_lambda

 def lipchitz_constant(self, class_weight):
-"""See super class.
-"""
+"""See super class."""
 # if class_weight is provided,
 # it should be a vector of the same size of number of classes
 max_class_weight = self.max_class_weight(class_weight, self.dtype)
@@ -208,10 +194,13 @@ class StrongConvexHuber(losses.Loss, StrongConvexMixin):
 return lc

 def kernel_regularizer(self):
-"""
-l2 loss using reg_lambda as the l2 term (as desired). Required for
-this loss function to be strongly convex.
-:return:
+"""Return l2 loss using 0.5*reg_lambda as the l2 term (as desired).
+
+L2 regularization is required for this loss function to be strongly convex.
+
+Returns:
+  The L2 regularizer layer for this loss function, with regularizer constant
+  set to half the 0.5 * reg_lambda.
 """
 return L1L2(l2=self.reg_lambda/2)
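The updated docstring above states that the kernel regularizer uses 0.5 * reg_lambda as its L2 coefficient. A short sketch (assuming standard TensorFlow/Keras regularizer semantics, where L1L2(l2=c)(w) returns c * sum(w**2)) showing that this matches the usual (lambda/2) * ||w||^2 strong-convexity term:

import tensorflow as tf

reg_lambda = 2.0
w = tf.constant([0.5, -1.0, 2.0])

penalty = tf.keras.regularizers.L1L2(l2=reg_lambda / 2)(w)   # (reg_lambda/2) * sum(w^2)
manual = 0.5 * reg_lambda * tf.reduce_sum(tf.square(w))      # (lambda/2) * ||w||^2
print(float(penalty), float(manual))  # both 5.25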
@@ -220,10 +209,7 @@ class StrongConvexBinaryCrossentropy(
 losses.BinaryCrossentropy,
 StrongConvexMixin
 ):
-"""
-Strong Convex version of BinaryCrossentropy loss using l2 weight
-regularization.
-"""
+"""Strongly Convex BinaryCrossentropy loss using l2 weight regularization."""

 def __init__(self,
 reg_lambda: float,
@@ -239,10 +225,12 @@ class StrongConvexBinaryCrossentropy(
 C: Penalty parameter C of the loss term
 radius_constant: constant defining the length of the radius
 reduction: reduction type to use. See super class
+from_logits: True if the input are unscaled logits. False if they are
+  already scaled.
 label_smoothing: amount of smoothing to perform on labels
-  relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x).
-  Note, the impact of this parameter's effect on privacy
-  is not known and thus the default should be used.
+  relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x). Note, the
+  impact of this parameter's effect on privacy is not known and thus the
+  default should be used.
 name: Name of the loss instance
 dtype: tf datatype to use for tensor conversions.
 """
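Given the constructor arguments documented above (and the [1, 1, 1, True, 0.1] init_args used later in the tests), a typical instantiation might look like the following sketch; the keyword usage is an assumption based on the docstring, not a verified API reference:

from privacy.bolton.losses import StrongConvexBinaryCrossentropy

loss = StrongConvexBinaryCrossentropy(
    reg_lambda=1.0,       # weight regularization constant
    C=1.0,                # penalty parameter on the cross-entropy term
    radius_constant=1.0,  # sets the radius R of the hypothesis space
    from_logits=True)     # model outputs are unscaled logits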
@@ -271,49 +259,322 @@ class StrongConvexBinaryCrossentropy(
 self.radius_constant = radius_constant

 def call(self, y_true, y_pred):
-"""Compute loss
+"""Computes loss

 Args:
 y_true: Ground truth values.
 y_pred: The predicted values.

 Returns:
 Loss values per sample.
 """
-# loss = tf.nn.sigmoid_cross_entropy_with_logits(
-# labels=y_true,
-# logits=y_pred
-# )
 loss = super(StrongConvexBinaryCrossentropy, self).call(y_true, y_pred)
 loss = loss * self.C
 return loss

 def radius(self):
-"""See super class.
-"""
+"""See super class."""
 return self.radius_constant / self.reg_lambda

 def gamma(self):
-"""See super class.
-"""
+"""See super class."""
 return self.reg_lambda

 def beta(self, class_weight):
-"""See super class.
-"""
+"""See super class."""
 max_class_weight = self.max_class_weight(class_weight, self.dtype)
 return self.C * max_class_weight + self.reg_lambda

 def lipchitz_constant(self, class_weight):
-"""See super class.
-"""
+"""See super class."""
 max_class_weight = self.max_class_weight(class_weight, self.dtype)
 return self.C * max_class_weight + self.reg_lambda * self.radius()

 def kernel_regularizer(self):
-"""
-l2 loss using reg_lambda as the l2 term (as desired). Required for
-this loss function to be strongly convex.
-:return:
+"""Return l2 loss using 0.5*reg_lambda as the l2 term (as desired).
+
+L2 regularization is required for this loss function to be strongly convex.
+
+Returns:
+  The L2 regularizer layer for this loss function, with regularizer constant
+  set to half the 0.5 * reg_lambda.
 """
 return L1L2(l2=self.reg_lambda/2)

+# class StrongConvexSparseCategoricalCrossentropy(
+# losses.CategoricalCrossentropy,
+# StrongConvexMixin
+# ):
+# """
+# Strong Convex version of CategoricalCrossentropy loss using l2 weight
+# regularization.
+# """
+#
+# def __init__(self,
+# reg_lambda: float,
+# C: float,
+# radius_constant: float,
+# from_logits: bool = True,
+# label_smoothing: float = 0,
+# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
+# name: str = 'binarycrossentropy',
+# dtype=tf.float32):
+# """
+# Args:
+# reg_lambda: Weight regularization constant
+# C: Penalty parameter C of the loss term
+# radius_constant: constant defining the length of the radius
+# reduction: reduction type to use. See super class
+# label_smoothing: amount of smoothing to perform on labels
+# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x)
+# name: Name of the loss instance
+# dtype: tf datatype to use for tensor conversions.
+# """
+# if reg_lambda <= 0:
+# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
+# if C <= 0:
+# raise ValueError('c: {0}, should be >= 0'.format(C))
+# if radius_constant <= 0:
+# raise ValueError('radius_constant: {0}, should be >= 0'.format(
+# radius_constant
+# ))
+#
+# self.C = C
+# self.dtype = dtype
+# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
+# super(StrongConvexSparseCategoricalCrossentropy, self).__init__(
+# reduction=reduction,
+# name=name,
+# from_logits=from_logits,
+# label_smoothing=label_smoothing,
+# )
+# self.radius_constant = radius_constant
+#
+# def call(self, y_true, y_pred):
+# """Compute loss
+#
+# Args:
+# y_true: Ground truth values.
+# y_pred: The predicted values.
+#
+# Returns:
+# Loss values per sample.
+# """
+# loss = super()
+# loss = loss * self.C
+# return loss
+#
+# def radius(self):
+# """See super class.
+# """
+# return self.radius_constant / self.reg_lambda
+#
+# def gamma(self):
+# """See super class.
+# """
+# return self.reg_lambda
+#
+# def beta(self, class_weight):
+# """See super class.
+# """
+# max_class_weight = self.max_class_weight(class_weight, self.dtype)
+# return self.C * max_class_weight + self.reg_lambda
+#
+# def lipchitz_constant(self, class_weight):
+# """See super class.
+# """
+# max_class_weight = self.max_class_weight(class_weight, self.dtype)
+# return self.C * max_class_weight + self.reg_lambda * self.radius()
+#
+# def kernel_regularizer(self):
+# """
+# l2 loss using reg_lambda as the l2 term (as desired). Required for
+# this loss function to be strongly convex.
+# :return:
+# """
+# return L1L2(l2=self.reg_lambda)
+#
+# class StrongConvexSparseCategoricalCrossentropy(
+# losses.SparseCategoricalCrossentropy,
+# StrongConvexMixin
+# ):
+# """
+# Strong Convex version of SparseCategoricalCrossentropy loss using l2 weight
+# regularization.
+# """
+#
+# def __init__(self,
+# reg_lambda: float,
+# C: float,
+# radius_constant: float,
+# from_logits: bool = True,
+# label_smoothing: float = 0,
+# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
+# name: str = 'binarycrossentropy',
+# dtype=tf.float32):
+# """
+# Args:
+# reg_lambda: Weight regularization constant
+# C: Penalty parameter C of the loss term
+# radius_constant: constant defining the length of the radius
+# reduction: reduction type to use. See super class
+# label_smoothing: amount of smoothing to perform on labels
+# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x)
+# name: Name of the loss instance
+# dtype: tf datatype to use for tensor conversions.
+# """
+# if reg_lambda <= 0:
+# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
+# if C <= 0:
+# raise ValueError('c: {0}, should be >= 0'.format(C))
+# if radius_constant <= 0:
+# raise ValueError('radius_constant: {0}, should be >= 0'.format(
+# radius_constant
+# ))
+#
+# self.C = C
+# self.dtype = dtype
+# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
+# super(StrongConvexHuber, self).__init__(reduction=reduction,
+# name=name,
+# from_logits=from_logits,
+# label_smoothing=label_smoothing,
+# )
+# self.radius_constant = radius_constant
+#
+# def call(self, y_true, y_pred):
+# """Compute loss
+#
+# Args:
+# y_true: Ground truth values.
+# y_pred: The predicted values.
+#
+# Returns:
+# Loss values per sample.
+# """
+# loss = super()
+# loss = loss * self.C
+# return loss
+#
+# def radius(self):
+# """See super class.
+# """
+# return self.radius_constant / self.reg_lambda
+#
+# def gamma(self):
+# """See super class.
+# """
+# return self.reg_lambda
+#
+# def beta(self, class_weight):
+# """See super class.
+# """
+# max_class_weight = self.max_class_weight(class_weight, self.dtype)
+# return self.C * max_class_weight + self.reg_lambda
+#
+# def lipchitz_constant(self, class_weight):
+# """See super class.
+# """
+# max_class_weight = self.max_class_weight(class_weight, self.dtype)
+# return self.C * max_class_weight + self.reg_lambda * self.radius()
+#
+# def kernel_regularizer(self):
+# """
+# l2 loss using reg_lambda as the l2 term (as desired). Required for
+# this loss function to be strongly convex.
+# :return:
+# """
+# return L1L2(l2=self.reg_lambda)
+#
+#
+# class StrongConvexCategoricalCrossentropy(
+# losses.CategoricalCrossentropy,
+# StrongConvexMixin
+# ):
+# """
+# Strong Convex version of CategoricalCrossentropy loss using l2 weight
+# regularization.
+# """
+#
+# def __init__(self,
+# reg_lambda: float,
+# C: float,
+# radius_constant: float,
+# from_logits: bool = True,
+# label_smoothing: float = 0,
+# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
+# name: str = 'binarycrossentropy',
+# dtype=tf.float32):
+# """
+# Args:
+# reg_lambda: Weight regularization constant
+# C: Penalty parameter C of the loss term
+# radius_constant: constant defining the length of the radius
+# reduction: reduction type to use. See super class
+# label_smoothing: amount of smoothing to perform on labels
+# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x)
+# name: Name of the loss instance
+# dtype: tf datatype to use for tensor conversions.
+# """
+# if reg_lambda <= 0:
+# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
+# if C <= 0:
+# raise ValueError('c: {0}, should be >= 0'.format(C))
+# if radius_constant <= 0:
+# raise ValueError('radius_constant: {0}, should be >= 0'.format(
+# radius_constant
+# ))
+#
+# self.C = C
+# self.dtype = dtype
+# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
+# super(StrongConvexHuber, self).__init__(reduction=reduction,
+# name=name,
+# from_logits=from_logits,
+# label_smoothing=label_smoothing,
+# )
+# self.radius_constant = radius_constant
+#
+# def call(self, y_true, y_pred):
+# """Compute loss
+#
+# Args:
+# y_true: Ground truth values.
+# y_pred: The predicted values.
+#
+# Returns:
+# Loss values per sample.
+# """
+# loss = super()
+# loss = loss * self.C
+# return loss
+#
+# def radius(self):
+# """See super class.
+# """
+# return self.radius_constant / self.reg_lambda
+#
+# def gamma(self):
+# """See super class.
+# """
+# return self.reg_lambda
+#
+# def beta(self, class_weight):
+# """See super class.
+# """
+# max_class_weight = self.max_class_weight(class_weight, self.dtype)
+# return self.C * max_class_weight + self.reg_lambda
+#
+# def lipchitz_constant(self, class_weight):
+# """See super class.
+# """
+# max_class_weight = self.max_class_weight(class_weight, self.dtype)
+# return self.C * max_class_weight + self.reg_lambda * self.radius()
+#
+# def kernel_regularizer(self):
+# """
+# l2 loss using reg_lambda as the l2 term (as desired). Required for
+# this loss function to be strongly convex.
+# :return:
+# """
+# return L1L2(l2=self.reg_lambda)
@@ -1,4 +1,4 @@
-# Copyright 2018, The TensorFlow Authors.
+# Copyright 2019, The TensorFlow Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,6 +17,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from contextlib import contextmanager
+from io import StringIO
+import sys
 import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.framework import test_util
@@ -27,6 +30,18 @@ from privacy.bolton.losses import StrongConvexHuber
 from privacy.bolton.losses import StrongConvexMixin


+@contextmanager
+def captured_output():
+"""Capture std_out and std_err within context."""
+new_out, new_err = StringIO(), StringIO()
+old_out, old_err = sys.stdout, sys.stderr
+try:
+sys.stdout, sys.stderr = new_out, new_err
+yield sys.stdout, sys.stderr
+finally:
+sys.stdout, sys.stderr = old_out, old_err
+
+
 class StrongConvexMixinTests(keras_parameterized.TestCase):
 """Tests for the StrongConvexMixin"""
 @parameterized.named_parameters([
@@ -72,7 +87,7 @@ class StrongConvexMixinTests(keras_parameterized.TestCase):


 class BinaryCrossesntropyTests(keras_parameterized.TestCase):
-"""tests for BinaryCrossesntropy StrongConvex loss"""
+"""tests for BinaryCrossesntropy StrongConvex loss."""

 @parameterized.named_parameters([
 {'testcase_name': 'normal',
@@ -82,7 +97,8 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
 },  # pylint: disable=invalid-name
 ])
 def test_init_params(self, reg_lambda, C, radius_constant):
-"""Test initialization for given arguments
+"""Test initialization for given arguments.

 Args:
 reg_lambda: initialization value for reg_lambda arg
 C: initialization value for C arg
@@ -111,6 +127,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
 ])
 def test_bad_init_params(self, reg_lambda, C, radius_constant):
 """Test invalid domain for given params. Should return ValueError

 Args:
 reg_lambda: initialization value for reg_lambda arg
 C: initialization value for C arg
@@ -146,6 +163,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
 ])
 def test_calculation(self, logits, y_true, result):
 """Test the call method to ensure it returns the correct value

 Args:
 logits: unscaled output of model
 y_true: label
@@ -185,6 +203,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
 ])
 def test_fns(self, init_args, fn, args, result):
 """Test that fn of BinaryCrossentropy loss returns the correct result

 Args:
 init_args: init values for loss instance
 fn: the fn to test
@@ -201,6 +220,29 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
 result = result.l2
 self.assertEqual(expected, result)

+@parameterized.named_parameters([
+{'testcase_name': 'label_smoothing',
+'init_args': [1, 1, 1, True, 0.1],
+'fn': None,
+'args': None,
+'print_res': 'The impact of label smoothing on privacy is unknown.'
+},
+])
+def test_prints(self, init_args, fn, args, print_res):
+"""Test logger warning from StrongConvexBinaryCrossentropy.
+
+Args:
+init_args: arguments to init the object with.
+fn: function to test
+args: arguments to above function
+print_res: print result that should have been printed.
+"""
+with captured_output() as (out, err):  # pylint: disable=unused-variable
+loss = StrongConvexBinaryCrossentropy(*init_args)
+if fn is not None:
+getattr(loss, fn, lambda *arguments: print('error'))(*args)
+self.assertRegexMatch(err.getvalue().strip(), [print_res])
+
+
 class HuberTests(keras_parameterized.TestCase):
 """tests for BinaryCrossesntropy StrongConvex loss"""
@@ -215,6 +257,7 @@ class HuberTests(keras_parameterized.TestCase):
 ])
 def test_init_params(self, reg_lambda, c, radius_constant, delta):
 """Test initialization for given arguments

 Args:
 reg_lambda: initialization value for reg_lambda arg
 C: initialization value for C arg
@@ -244,7 +287,7 @@ class HuberTests(keras_parameterized.TestCase):
 'delta': 1
 },
 {'testcase_name': 'negative delta',
-'reg_lambda': -1,
+'reg_lambda': 1,
 'c': 1,
 'radius_constant': 1,
 'delta': -1
@@ -252,10 +295,12 @@ class HuberTests(keras_parameterized.TestCase):
 ])
 def test_bad_init_params(self, reg_lambda, c, radius_constant, delta):
 """Test invalid domain for given params. Should return ValueError

 Args:
 reg_lambda: initialization value for reg_lambda arg
 C: initialization value for C arg
 radius_constant: initialization value for radius_constant arg
+delta: the delta parameter for the huber loss
 """
 # test valid domains for each variable
 with self.assertRaises(ValueError):
@@ -321,6 +366,7 @@ class HuberTests(keras_parameterized.TestCase):
 ])
 def test_calculation(self, logits, y_true, delta, result):
 """Test the call method to ensure it returns the correct value

 Args:
 logits: unscaled output of model
 y_true: label
@@ -360,6 +406,7 @@ class HuberTests(keras_parameterized.TestCase):
 ])
 def test_fns(self, init_args, fn, args, result):
 """Test that fn of BinaryCrossentropy loss returns the correct result

 Args:
 init_args: init values for loss instance
 fn: the fn to test
@@ -1,4 +1,4 @@
-# Copyright 2018, The TensorFlow Authors.
+# Copyright 2019, The TensorFlow Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -25,8 +25,11 @@ from privacy.bolton.optimizers import Bolton


 class BoltonModel(Model):
-"""
-Bolton episilon-delta model
+"""Bolton episilon-delta differential privacy model.
+
+The privacy guarantees are dependent on the noise that is sampled. Please
+see the paper linked below for more details.
+
 Uses 4 key steps to achieve privacy guarantees:
 1. Adds noise to weights after training (output perturbation).
 2. Projects weights to R after each batch
@@ -121,8 +124,9 @@ class BoltonModel(Model):
 noise_distribution='laplace',
 steps_per_epoch=None,
 **kwargs):  # pylint: disable=arguments-differ
-"""Reroutes to super fit with additional Bolton delta-epsilon privacy
-requirements implemented. Note, inputs must be normalized s.t. ||x|| < 1
+"""Reroutes to super fit with Bolton delta-epsilon privacy requirements.
+
+Note, inputs must be normalized s.t. ||x|| < 1.
 Requirements are as follows:
 1. Adds noise to weights after training (output perturbation).
 2. Projects weights to R after each batch
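The fit docstring above requires inputs normalized so that ||x|| < 1. One way to satisfy that, sketched here with NumPy (not part of the commit), is to rescale the whole feature matrix by its largest row norm:

import numpy as np

def scale_to_unit_ball(x, eps=1e-6):
  """Rescale x so every row has L2 norm strictly below 1."""
  max_norm = np.linalg.norm(x, axis=1).max()
  if max_norm >= 1.0:
    x = x / (max_norm + eps)
  return x

x = scale_to_unit_ball(np.random.randn(100, 16))
assert np.linalg.norm(x, axis=1).max() < 1.0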
@@ -139,7 +143,6 @@ class BoltonModel(Model):
 whose dim == n_classes.

 See the super method for descriptions on the rest of the arguments.
-
 """
 if class_weight is None:
 class_weight_ = self.calculate_class_weights(class_weight)
@@ -237,8 +240,8 @@ class BoltonModel(Model):
 class_counts=None,
 num_classes=None
 ):
-"""
-Calculates class weighting to be used in training. Can be on
+"""Calculates class weighting to be used in training.
+
 Args:
 class_weights: str specifying type, array giving weights, or None.
 class_counts: If class_weights is not None, then an array of
@@ -246,7 +249,6 @@ class BoltonModel(Model):
 num_classes: If class_weights is not None, then the number of
 classes.
 Returns: class_weights as 1D tensor, to be passed to model's fit method.
-
 """
 # Value checking
 class_keys = ['balanced']
@@ -1,4 +1,4 @@
-# Copyright 2018, The TensorFlow Authors.
+# Copyright 2019, The TensorFlow Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -38,40 +38,36 @@ class TestLoss(losses.Loss, StrongConvexMixin):
 self.radius_constant = radius_constant

 def radius(self):
-"""Radius of R-Ball (value to normalize weights to after each batch)
+"""Radius, R, of the hypothesis space W.
+W is a convex set that forms the hypothesis space.

 Returns: radius

 """
 return _ops.convert_to_tensor_v2(1, dtype=tf.float32)

 def gamma(self):
-""" Gamma strongly convex
-
-Returns: gamma
-
-"""
+"""Returns strongly convex parameter, gamma."""
 return _ops.convert_to_tensor_v2(1, dtype=tf.float32)

 def beta(self, class_weight): # pylint: disable=unused-argument
-"""Beta smoothess
+"""Smoothness, beta.

 Args:
-class_weight: the class weights used.
-Returns: Beta
+class_weight: the class weights as scalar or 1d tensor, where its
+dimensionality is equal to the number of outputs.
+
+Returns:
+  Beta
 """
 return _ops.convert_to_tensor_v2(1, dtype=tf.float32)

 def lipchitz_constant(self, class_weight): # pylint: disable=unused-argument
-""" L lipchitz continuous
+"""Lipchitz constant, L.

 Args:
 class_weight: class weights used

 Returns: L

 """
 return _ops.convert_to_tensor_v2(1, dtype=tf.float32)

@@ -83,11 +79,25 @@ class TestLoss(losses.Loss, StrongConvexMixin):
 )

 def max_class_weight(self, class_weight):
+"""the maximum weighting in class weights (max value) as a scalar tensor
+
+Args:
+class_weight: class weights used
+dtype: the data type for tensor conversions.
+
+Returns:
+  maximum class weighting as tensor scalar
+"""
 if class_weight is None:
 return 1
 raise ValueError('')

 def kernel_regularizer(self):
+"""Returns the kernel_regularizer to be used.
+
+Any subclass should override this method if they want a kernel_regularizer
+(if required for the loss function to be StronglyConvex.
+"""
 return L1L2(l2=self.reg_lambda)

@@ -113,7 +123,7 @@ class TestOptimizer(OptimizerV2):


 class InitTests(keras_parameterized.TestCase):
-"""tests for keras model initialization"""
+"""Tests for keras model initialization."""

 @parameterized.named_parameters([
 {'testcase_name': 'normal',
|
||||||
},
|
},
|
||||||
])
|
])
|
||||||
def test_init_params(self, n_outputs):
|
def test_init_params(self, n_outputs):
|
||||||
"""test initialization of BoltonModel
|
"""Test initialization of BoltonModel.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
n_outputs: number of output neurons
|
n_outputs: number of output neurons
|
||||||
|
@@ -243,8 +253,7 @@ def _do_fit(n_samples,
 optimizer,
 loss,
 distribution='laplace'):
-"""Helper to instantiate necessary components for fitting and perform a model
-fit.
+"""Instantiate necessary components for fitting and perform a model fit.

 Args:
 n_samples: number of samples in dataset
@@ -1,4 +1,4 @@
-# Copyright 2018, The TensorFlow Authors.
+# Copyright 2019, The TensorFlow Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -26,11 +26,9 @@ _accepted_distributions = ['laplace']  # implemented distributions for noising


 class GammaBetaDecreasingStep(
-optimizer_v2.learning_rate_schedule.LearningRateSchedule
-):
-"""
-Learning Rate Scheduler using the minimum of 1/beta and 1/(gamma * step)
-at each step. A required step for privacy guarantees.
+optimizer_v2.learning_rate_schedule.LearningRateSchedule):
+"""Computes LR as minimum of 1/beta and 1/(gamma * step) at each step.
+
+A required step for privacy guarantees.
 """
 def __init__(self):
 self.is_init = False
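The new class docstring says the schedule returns the minimum of 1/beta and 1/(gamma * step). A standalone sketch of that rule in plain Python (the real class pulls beta and gamma from the strongly convex loss during fit):

def bolton_learning_rate(step, beta, gamma):
  """Bolton learning rate: min(1/beta, 1/(gamma * step))."""
  return min(1.0 / beta, 1.0 / (gamma * step))

# Early steps are capped at 1/beta; later steps decay as 1/(gamma * step).
print([round(bolton_learning_rate(s, beta=2.0, gamma=0.5), 3) for s in range(1, 6)])
# [0.5, 0.5, 0.5, 0.5, 0.4]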
@@ -38,8 +36,7 @@ class GammaBetaDecreasingStep(
 self.gamma = None

 def __call__(self, step):
-"""
-returns the learning rate
+"""Computes and returns the learning rate.
+
 Args:
 step: the current iteration number
 Returns:
@@ -61,15 +58,14 @@ class GammaBetaDecreasingStep(
 )

 def get_config(self):
-"""
-config to setup the learning rate scheduler.
-"""
+"""Return config to setup the learning rate scheduler."""
 return {'beta': self.beta, 'gamma': self.gamma}

 def initialize(self, beta, gamma):
-"""setup the learning rate scheduler with the beta and gamma values provided
-by the loss function. Meant to be used with .fit as the loss params may
-depend on values passed to fit.
+"""Setups scheduler with beta and gamma values from the loss function.
+
+Meant to be used with .fit as the loss params may depend on values passed to
+fit.

 Args:
 beta: Smoothness value. See StrongConvexMixin
@@ -80,37 +76,36 @@ class GammaBetaDecreasingStep(
 self.gamma = gamma

 def de_initialize(self):
-"""De initialize the scheduler after fitting, in case another fit call has
-different loss parameters.
-"""
+"""De initialize post fit, as another fit call may use other parameters."""
 self.is_init = False
 self.beta = None
 self.gamma = None


 class Bolton(optimizer_v2.OptimizerV2):
-"""
+"""Wrap another tf optimizer with Bolton privacy protocol.
+
 Bolton optimizer wraps another tf optimizer to be used
 as the visible optimizer to the tf model. No matter the optimizer
 passed, "Bolton" enables the bolton model to control the learning rate
 based on the strongly convex loss.

 To use the Bolton method, you must:
 1. instantiate it with an instantiated tf optimizer and StrongConvexLoss.
 2. use it as a context manager around your .fit method internals.

 This can be accomplished by the following:
 optimizer = tf.optimizers.SGD()
 loss = privacy.bolton.losses.StrongConvexBinaryCrossentropy()
 bolton = Bolton(optimizer, loss)
 with bolton(*args) as _:
 model.fit()
 The args required for the context manager can be found in the __call__
 method.

 For more details on the strong convexity requirements, see:
 Bolt-on Differential Privacy for Scalable Stochastic Gradient
 Descent-based Analytics by Xi Wu et. al.
 """
 def __init__(self,  # pylint: disable=super-init-not-called
 optimizer: optimizer_v2.OptimizerV2,
@@ -120,9 +115,9 @@ class Bolton(optimizer_v2.OptimizerV2):
 """Constructor.

 Args:
 optimizer: Optimizer_v2 or subclass to be used as the optimizer
 (wrapped).
 loss: StrongConvexLoss function that the model is being compiled with.
 """

 if not isinstance(loss, StrongConvexMixin):
@@ -150,19 +145,15 @@ class Bolton(optimizer_v2.OptimizerV2):
 self._is_init = False

 def get_config(self):
-"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
-"""
+"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
 return self._internal_optimizer.get_config()

 def project_weights_to_r(self, force=False):
-"""helper method to normalize the weights to the R-ball.
+"""Normalize the weights to the R-ball.

 Args:
 force: True to normalize regardless of previous weight values.
 False to check if weights > R-ball and only normalize then.

-Returns:
-
 """
 if not self._is_init:
 raise Exception('This method must be called from within the optimizer\'s '
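project_weights_to_r, documented above, pulls the weights back onto the R-ball after updates. A minimal NumPy sketch of that projection (illustrative only; the optimizer itself operates on the model's layer variables and its radius comes from the strongly convex loss):

import numpy as np

def project_to_r_ball(w, radius):
  """Project w onto the L2 ball of the given radius (no-op if already inside)."""
  norm = np.linalg.norm(w)
  return w * (radius / norm) if norm > radius else w

print(project_to_r_ball(np.array([3.0, 4.0]), radius=1.0))  # [0.6 0.8]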
@@ -186,8 +177,8 @@ class Bolton(optimizer_v2.OptimizerV2):
 input_dim: the input dimensionality for the weights
 output_dim the output dimensionality for the weights

-Returns: noise in shape of layer's weights to be added to the weights.
+Returns:
+  Noise in shape of layer's weights to be added to the weights.
 """
 if not self._is_init:
 raise Exception('This method must be called from within the optimizer\'s '
@@ -221,8 +212,7 @@ class Bolton(optimizer_v2.OptimizerV2):
 'a valid distribution'.format(distribution))

 def from_config(self, *args, **kwargs): # pylint: disable=arguments-differ
-"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
-"""
+"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
 return self._internal_optimizer.from_config(*args, **kwargs)

 def __getattr__(self, name):
@@ -230,11 +220,10 @@ class Bolton(optimizer_v2.OptimizerV2):
 from the _internal_optimizer instance.

 Args:
 name:

 Returns: attribute from Bolton if specified to come from self, else
 from _internal_optimizer.
-
 """
 if name == '_private_attributes' or name in self._private_attributes:
 return getattr(self, name)
|
||||||
Reroute everything else to the _internal_optimizer.
|
Reroute everything else to the _internal_optimizer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: attribute name
|
key: attribute name
|
||||||
value: attribute value
|
value: attribute value
|
||||||
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if key == '_private_attributes':
|
if key == '_private_attributes':
|
||||||
object.__setattr__(self, key, value)
|
object.__setattr__(self, key, value)
|
||||||
|
@ -269,44 +255,37 @@ class Bolton(optimizer_v2.OptimizerV2):
|
||||||
setattr(self._internal_optimizer, key, value)
|
setattr(self._internal_optimizer, key, value)
|
||||||
|
|
||||||
def _resource_apply_dense(self, *args, **kwargs): # pylint: disable=arguments-differ
|
def _resource_apply_dense(self, *args, **kwargs): # pylint: disable=arguments-differ
|
||||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
|
||||||
"""
|
|
||||||
return self._internal_optimizer._resource_apply_dense(*args, **kwargs) # pylint: disable=protected-access
|
return self._internal_optimizer._resource_apply_dense(*args, **kwargs) # pylint: disable=protected-access
|
||||||
|
|
||||||
def _resource_apply_sparse(self, *args, **kwargs): # pylint: disable=arguments-differ
|
def _resource_apply_sparse(self, *args, **kwargs): # pylint: disable=arguments-differ
|
||||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
|
||||||
"""
|
|
||||||
return self._internal_optimizer._resource_apply_sparse(*args, **kwargs) # pylint: disable=protected-access
|
return self._internal_optimizer._resource_apply_sparse(*args, **kwargs) # pylint: disable=protected-access
|
||||||
|
|
||||||
def get_updates(self, loss, params):
|
def get_updates(self, loss, params):
|
||||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
|
||||||
"""
|
|
||||||
out = self._internal_optimizer.get_updates(loss, params)
|
out = self._internal_optimizer.get_updates(loss, params)
|
||||||
self.project_weights_to_r()
|
self.project_weights_to_r()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def apply_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ
|
def apply_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ
|
||||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
|
||||||
"""
|
|
||||||
out = self._internal_optimizer.apply_gradients(*args, **kwargs)
|
out = self._internal_optimizer.apply_gradients(*args, **kwargs)
|
||||||
self.project_weights_to_r()
|
self.project_weights_to_r()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def minimize(self, *args, **kwargs): # pylint: disable=arguments-differ
|
def minimize(self, *args, **kwargs): # pylint: disable=arguments-differ
|
||||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
|
||||||
"""
|
|
||||||
out = self._internal_optimizer.minimize(*args, **kwargs)
|
out = self._internal_optimizer.minimize(*args, **kwargs)
|
||||||
self.project_weights_to_r()
|
self.project_weights_to_r()
|
||||||
return out
|
return out
|
||||||
|
|
||||||
  def _compute_gradients(self, *args, **kwargs):  # pylint: disable=arguments-differ,protected-access
    """Reroutes to _internal_optimizer. See super/_internal_optimizer."""
    return self._internal_optimizer._compute_gradients(*args, **kwargs)  # pylint: disable=protected-access

  def get_gradients(self, *args, **kwargs):  # pylint: disable=arguments-differ
    """Reroutes to _internal_optimizer. See super/_internal_optimizer."""
    return self._internal_optimizer.get_gradients(*args, **kwargs)

  def __enter__(self):

@ -326,8 +305,8 @@ class Bolton(optimizer_v2.OptimizerV2):
               n_samples,
               batch_size
               ):
    """Accepts required values for bolton method from context entry point.

    Stores them on the optimizer for use throughout fitting.

    Args:
      noise_distribution: the noise distribution to pick.
@ -360,17 +339,15 @@ class Bolton(optimizer_v2.OptimizerV2):

  def __exit__(self, *args):
    """Exit call from with statement.

    Used to:

    1. reset the model and fit parameters passed to the optimizer
       to enable the Bolton privacy guarantees. These are reset to ensure
       that any future calls to fit with the same instance of the optimizer
       will properly error out.

    2. call post-fit methods normalizing/projecting the model weights and
       adding noise to the weights.
    """
    self.project_weights_to_r(True)
    for layer in self.layers:

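Together, __enter__ and __exit__ let the optimizer act as a context manager around fitting: the privacy and fit parameters are stored on entry, and on exit the weights are projected to the R-ball, noised, and the stored parameters reset. A minimal sketch of that flow (mirroring the tutorial added at the end of this change), assuming a small Keras model and training data x, y like the toy dataset built there:

import tensorflow as tf
from privacy.bolton.losses import StrongConvexBinaryCrossentropy
from privacy.bolton.optimizers import Bolton

loss = StrongConvexBinaryCrossentropy(1, 1, 1)  # reg_lambda, C, radius_constant
bolton = Bolton(tf.optimizers.SGD(), loss)
model = tf.keras.Sequential(
    [tf.keras.layers.Dense(1, kernel_regularizer=loss.kernel_regularizer())])
model.compile(bolton, loss)

# __enter__ stores the parameters below; __exit__ projects the weights,
# adds noise, and resets the stored fit parameters.
with bolton(noise_distribution='laplace',
            epsilon=2,
            layers=model.layers,
            class_weights=1,
            n_samples=20,
            batch_size=5) as _:
  model.fit(x, y, batch_size=5, epochs=2)  # x, y: training data, assumed defined
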
@ -1,4 +1,4 @@
# Copyright 2018, The TensorFlow Authors.
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -33,8 +33,7 @@ from privacy.bolton import optimizers as opt


class TestModel(Model):
  """Bolton epsilon-delta model.

  Uses 4 key steps to achieve privacy guarantees:
  1. Adds noise to weights after training (output perturbation).
  2. Projects weights to R after each batch

@ -47,14 +46,15 @@ class TestModel(Model):
  """

  def __init__(self, n_outputs=2, input_shape=(16,), init_value=2):
    """Constructor.

    Args:
      n_outputs: number of output neurons
      epsilon: level of privacy guarantee
      noise_distribution: distribution to pull weight perturbations from
      weights_initializer: initializer for weights
      seed: random seed to use
      dtype: data type to use for tensors
    """
    super(TestModel, self).__init__(name='bolton', dynamic=False)
    self.n_outputs = n_outputs

@ -76,40 +76,36 @@ class TestLoss(losses.Loss, StrongConvexMixin):
    self.radius_constant = radius_constant

  def radius(self):
    """Radius, R, of the hypothesis space W.

    W is a convex set that forms the hypothesis space.

    Returns: radius
    """
    return _ops.convert_to_tensor_v2(self.radius_constant, dtype=tf.float32)

  def gamma(self):
    """Returns strongly convex parameter, gamma."""
    return _ops.convert_to_tensor_v2(1, dtype=tf.float32)

  def beta(self, class_weight):  # pylint: disable=unused-argument
    """Smoothness, beta.

    Args:
      class_weight: the class weights as scalar or 1d tensor, where its
        dimensionality is equal to the number of outputs.

    Returns:
      Beta
    """
    return _ops.convert_to_tensor_v2(1, dtype=tf.float32)

  def lipchitz_constant(self, class_weight):  # pylint: disable=unused-argument
    """Lipchitz constant, L.

    Args:
      class_weight: class weights used

    Returns: L
    """
    return _ops.convert_to_tensor_v2(1, dtype=tf.float32)

@ -121,11 +117,25 @@ class TestLoss(losses.Loss, StrongConvexMixin):
    )

  def max_class_weight(self, class_weight, dtype=tf.float32):
    """The maximum weighting in class weights (max value) as a scalar tensor.

    Args:
      class_weight: class weights used
      dtype: the data type for tensor conversions.

    Returns:
      maximum class weighting as tensor scalar
    """
    if class_weight is None:
      return 1
    raise NotImplementedError('')

  def kernel_regularizer(self):
    """Returns the kernel_regularizer to be used.

    Any subclass should override this method if they want a kernel_regularizer
    (if required for the loss function to be StronglyConvex).
    """
    return L1L2(l2=self.reg_lambda)

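kernel_regularizer above returns the L2 penalty that makes the loss strongly convex; a model trained with such a loss is expected to attach it to its kernel, as the tutorial below does. A minimal sketch, assuming `loss` is an instance of a StrongConvexMixin loss like the one defined here:

import tensorflow as tf

# Attach the loss's L2 kernel regularizer to the layer whose weights are trained.
layer = tf.keras.layers.Dense(1, kernel_regularizer=loss.kernel_regularizer())
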
@ -1,432 +0,0 @@
[Deleted Jupyter notebook (432 lines): its markdown and code cells are carried over into the new tutorials/bolton_tutorial.py below; only execution output, warnings, and notebook metadata are dropped.]
155 tutorials/bolton_tutorial.py Normal file
@ -0,0 +1,155 @@
import sys

sys.path.append('..')
import tensorflow as tf
from privacy.bolton import losses
from privacy.bolton import models

"""First, we will create a binary classification dataset with a single output
dimension. The samples for each label are repeated data points at different
points in space."""
# Parameters for dataset
n_samples = 10
input_dim = 2
n_outputs = 1
# Create binary classification dataset:
x_stack = [tf.constant(-1, tf.float32, (n_samples, input_dim)),
           tf.constant(1, tf.float32, (n_samples, input_dim))]
y_stack = [tf.constant(0, tf.float32, (n_samples, 1)),
           tf.constant(1, tf.float32, (n_samples, 1))]
x, y = tf.concat(x_stack, 0), tf.concat(y_stack, 0)
print(x.shape, y.shape)
generator = tf.data.Dataset.from_tensor_slices((x, y))
generator = generator.batch(10)
generator = generator.shuffle(10)

"""First, we will explore using the pre-built BoltonModel, which is a thin
wrapper around a Keras Model using a single-layer neural network.
It automatically uses the Bolton Optimizer, which encompasses all the logic
required for the Bolton Differential Privacy method."""
bolt = models.BoltonModel(n_outputs)  # tell the model how many outputs we have.

"""Now, we will pick our optimizer and Strongly Convex Loss function. The loss
must extend from StrongConvexMixin and implement the associated methods. Some
existing loss functions are pre-implemented in bolton.losses."""
optimizer = tf.optimizers.SGD()
reg_lambda = 1
C = 1
radius_constant = 1
loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)

"""For simplicity, we pick all parameters of the StrongConvexBinaryCrossentropy
to be 1; these are all tunable and their impact can be read in
losses.StrongConvexBinaryCrossentropy. We then compile the model with the chosen
optimizer and loss, which will automatically wrap the chosen optimizer with the
Bolton Optimizer, ensuring the required components function as required for
privacy guarantees."""
bolt.compile(optimizer, loss)

"""To fit the model, the optimizer will require additional information about
the dataset and model. These parameters are:
1. the class_weights used
2. the number of samples in the dataset
3. the batch size
The model will try to infer these, if possible. If not, you will be required to
pass them explicitly to the fit method.

As well, there are two privacy parameters that can be altered:
1. epsilon, a float
2. noise_distribution, a valid string indicating the distribution to use (must
be implemented)

The BoltonModel offers a helper method, .calculate_class_weight, to aid in
class_weight calculation."""
# required parameters
class_weight = None  # default, use .calculate_class_weight to specify other values
batch_size = None  # default, if it cannot be inferred, specify this
n_samples = None  # default, if it cannot be inferred, specify this
# privacy parameters
epsilon = 2
noise_distribution = 'laplace'

bolt.fit(x,
         y,
         epsilon=epsilon,
         class_weight=class_weight,
         batch_size=batch_size,
         n_samples=n_samples,
         noise_distribution=noise_distribution,
         epochs=2)

"""We may also train a generator object, or try different optimizers and loss
functions. Below, we will see that we must pass the number of samples, as the
fit method is unable to infer it for a generator."""
optimizer2 = tf.optimizers.Adam()
bolt.compile(optimizer2, loss)
# required parameters
class_weight = None  # default, use .calculate_class_weight to specify other values
batch_size = None  # default, if it cannot be inferred, specify this
n_samples = None  # default, if it cannot be inferred, specify this
# privacy parameters
epsilon = 2
noise_distribution = 'laplace'
try:
  bolt.fit(generator,
           epsilon=epsilon,
           class_weight=class_weight,
           batch_size=batch_size,
           n_samples=n_samples,
           noise_distribution=noise_distribution,
           verbose=0
           )
except ValueError as e:
  print(e)

"""And now, re-running with the parameter set."""
n_samples = 20
bolt.fit(generator,
         epsilon=epsilon,
         class_weight=class_weight,
         batch_size=batch_size,
         n_samples=n_samples,
         noise_distribution=noise_distribution,
         verbose=0
         )

"""You don't have to use the BoltonModel to use the Bolton method.
There are only a few requirements:
1. make sure any requirements from the loss are implemented in the model.
2. instantiate the optimizer and use it as a context around your fit operation.
"""

from privacy.bolton.optimizers import Bolton

"""Here, we create our own model and set up the Bolton optimizer."""


class TestModel(tf.keras.Model):
  def __init__(self, reg_layer, n_outputs=1):
    super(TestModel, self).__init__(name='test')
    self.output_layer = tf.keras.layers.Dense(n_outputs,
                                              kernel_regularizer=reg_layer
                                              )

  def call(self, inputs):
    return self.output_layer(inputs)


optimizer = tf.optimizers.SGD()
loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
optimizer = Bolton(optimizer, loss)

"""Now, we instantiate our model and check for requirement 1. Since our loss
requires L2 regularization over the kernel, we will pass it to the model."""
n_outputs = 1  # parameter for model and optimizer context.
test_model = TestModel(loss.kernel_regularizer(), n_outputs)
test_model.compile(optimizer, loss)

"""We comply with requirement 2 and use the Bolton Optimizer as a context
around the fit method."""
# parameters for context
noise_distribution = 'laplace'
epsilon = 2
class_weights = 1  # Previously, the fit method auto-detected the class_weights.
# Here, we need to pass the class_weights explicitly. 1 is the equivalent of None.
n_samples = 20
batch_size = 5

with optimizer(
    noise_distribution=noise_distribution,
    epsilon=epsilon,
    layers=test_model.layers,
    class_weights=class_weights,
    n_samples=n_samples,
    batch_size=batch_size
) as _:
  test_model.fit(x, y, batch_size=batch_size, epochs=2)
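Once the context exits, the projected and noised weights are in place and the model behaves like any other Keras model at inference time. A short follow-up check (not part of the 155-line tutorial file above), assuming the objects defined there:

preds = test_model.predict(x)  # predictions from the privately trained weights
print(preds.shape)             # (20, 1) for the toy dataset built above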