forked from 626_privacy/tensorflow_privacy
Bolton created as optimizer with context manager usage.
Unit tests included. Additional loss functions TBD.
parent: ec18db5ec5, commit: 935d6e8480
6 changed files with 685 additions and 661 deletions
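A rough usage sketch of the context-manager pattern this commit introduces (illustrative only, not part of the diff; module paths and constructor arguments follow the imports and the demo code in the changed files below):

import tensorflow as tf
from privacy.bolton import loss
from privacy.bolton import model

# Strongly convex loss: reg_lambda=1, C=2, radius_constant=1 (as in the demo script below).
strong_loss = loss.StrongConvexBinaryCrossentropy(1, 2, 1)
clf = model.BoltonModel(n_outputs=1)
# compile() wraps the passed optimizer in the Bolton optimizer and attaches the
# loss's kernel regularizer to the single output layer.
clf.compile(tf.optimizers.SGD(), strong_loss)
# fit() then uses the Bolton optimizer as a context manager: it initializes the
# beta/gamma learning-rate schedule from the loss, projects weights to the
# R-ball each step, and adds calibrated noise to the kernel on exit.
# Inputs must be normalized so that ||x|| < 1 per sample.
# clf.fit(x, y, epsilon=2, noise_distribution='laplace', n_samples=n)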
|
@@ -102,7 +102,7 @@ class StrongConvexMixin:
|
|||
return tf.math.reduce_max(class_weight)
|
||||
|
||||
|
||||
class StrongConvexHuber(losses.Huber, StrongConvexMixin):
|
||||
class StrongConvexHuber(losses.Loss, StrongConvexMixin):
|
||||
"""Strong Convex version of Huber loss using l2 weight regularization.
|
||||
"""
|
||||
|
||||
|
@@ -112,7 +112,6 @@ class StrongConvexHuber(losses.Huber, StrongConvexMixin):
|
|||
radius_constant: float,
|
||||
delta: float,
|
||||
reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
|
||||
name: str = 'huber',
|
||||
dtype=tf.float32):
|
||||
"""Constructor.
|
||||
|
||||
|
@@ -137,13 +136,17 @@ class StrongConvexHuber(losses.Huber, StrongConvexMixin):
|
|||
raise ValueError('radius_constant: {0}, should be >= 0'.format(
|
||||
radius_constant
|
||||
))
|
||||
self.C = C
|
||||
if delta <= 0:
|
||||
raise ValueError('delta: {0}, should be >= 0'.format(
|
||||
delta
|
||||
))
|
||||
self.C = C # pylint: disable=invalid-name
|
||||
self.delta = delta
|
||||
self.radius_constant = radius_constant
|
||||
self.dtype = dtype
|
||||
self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
|
||||
super(StrongConvexHuber, self).__init__(
|
||||
delta=delta,
|
||||
name=name,
|
||||
name='huber',
|
||||
reduction=reduction,
|
||||
)
|
||||
|
||||
|
@@ -151,26 +154,25 @@ class StrongConvexHuber(losses.Huber, StrongConvexMixin):
|
|||
"""Compute loss
|
||||
|
||||
Args:
|
||||
y_true: Ground truth values.
|
||||
y_true: Ground truth values. One
|
||||
y_pred: The predicted values.
|
||||
|
||||
Returns:
|
||||
Loss values per sample.
|
||||
"""
|
||||
# return super(StrongConvexHuber, self).call(y_true, y_pred) * self._sample_weight
|
||||
h = self._fn_kwargs['delta']
|
||||
h = self.delta
|
||||
z = y_pred * y_true
|
||||
one = tf.constant(1, dtype=self.dtype)
|
||||
four = tf.constant(4, dtype=self.dtype)
|
||||
|
||||
if z > one + h:
|
||||
return z - z
|
||||
return _ops.convert_to_tensor_v2(0, dtype=self.dtype)
|
||||
elif tf.math.abs(one - z) <= h:
|
||||
return one / (four * h) * tf.math.pow(one + h - z, 2)
|
||||
elif z < one - h:
|
||||
return one - z
|
||||
else:
|
||||
raise ValueError('')
|
||||
raise ValueError('') # shouldn't be possible to get here.
|
||||
|
||||
def radius(self):
|
||||
"""See super class.
|
||||
|
@@ -186,7 +188,7 @@ class StrongConvexHuber(losses.Huber, StrongConvexMixin):
|
|||
"""See super class.
|
||||
"""
|
||||
max_class_weight = self.max_class_weight(class_weight, self.dtype)
|
||||
delta = _ops.convert_to_tensor_v2(self._fn_kwargs['delta'],
|
||||
delta = _ops.convert_to_tensor_v2(self.delta,
|
||||
dtype=self.dtype
|
||||
)
|
||||
return self.C * max_class_weight / (delta *
|
||||
|
@@ -250,7 +252,7 @@ class StrongConvexBinaryCrossentropy(
|
|||
radius_constant
|
||||
))
|
||||
self.dtype = dtype
|
||||
self.C = C
|
||||
self.C = C # pylint: disable=invalid-name
|
||||
self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
|
||||
super(StrongConvexBinaryCrossentropy, self).__init__(
|
||||
reduction=reduction,
|
||||
|
@@ -306,7 +308,7 @@ class StrongConvexBinaryCrossentropy(
|
|||
this loss function to be strongly convex.
|
||||
:return:
|
||||
"""
|
||||
return L1L2(l2=self.reg_lambda)
|
||||
return L1L2(l2=self.reg_lambda/2)
|
||||
|
||||
|
||||
# class StrongConvexSparseCategoricalCrossentropy(
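# For reference, a scalar re-derivation of the piecewise computation in
# StrongConvexHuber.call earlier in this file (illustrative only; h stands for
# delta and z = y_pred * y_true, exactly as in the call body above):
def strong_convex_huber_point(y_true, y_pred, h):
    """Scalar sketch of StrongConvexHuber.call, not part of the library."""
    z = y_pred * y_true
    if z > 1 + h:              # confidently correct: no loss
        return 0.0
    if abs(1 - z) <= h:        # quadratic smoothing region
        return (1 + h - z) ** 2 / (4 * h)
    return 1 - z               # z < 1 - h: linear, hinge-like region

# e.g. with h = 0.5: z = 2.0 -> 0.0, z = 1.2 -> 0.045, z = 0.0 -> 1.0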
@@ -79,7 +79,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
'reg_lambda': 1,
|
||||
'C': 1,
|
||||
'radius_constant': 1
|
||||
},
|
||||
}, # pylint: disable=invalid-name
|
||||
])
|
||||
def test_init_params(self, reg_lambda, C, radius_constant):
|
||||
"""Test initialization for given arguments
|
||||
|
@@ -107,7 +107,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
|
|||
'reg_lambda': -1,
|
||||
'C': 1,
|
||||
'radius_constant': 1
|
||||
},
|
||||
}, # pylint: disable=invalid-name
|
||||
])
|
||||
def test_bad_init_params(self, reg_lambda, C, radius_constant):
|
||||
"""Test invalid domain for given params. Should return ValueError
|
||||
|
@@ -180,7 +180,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
|
|||
'fn': 'kernel_regularizer',
|
||||
'init_args': [1, 1, 1],
|
||||
'args': [],
|
||||
'result': L1L2(l2=1),
|
||||
'result': L1L2(l2=0.5),
|
||||
},
|
||||
])
|
||||
def test_fns(self, init_args, fn, args, result):
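# The expected value L1L2(l2=0.5) above follows from the change of
# kernel_regularizer() to L1L2(l2=reg_lambda/2): Keras L1L2 computes
# l2 * sum(w ** 2), so l2=reg_lambda/2 gives the (reg_lambda/2) * ||w||^2
# strongly convex penalty. A small illustrative check (not part of the tests):
import tensorflow as tf
from tensorflow.python.keras.regularizers import L1L2

reg_lambda = 1.0
reg = L1L2(l2=reg_lambda / 2)          # what kernel_regularizer() now returns
w = tf.constant([[3.0], [4.0]])        # ||w||^2 = 25
penalty = reg(w)                       # 0.5 * 25 = 12.5, i.e. (reg_lambda/2) * ||w||^2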
@@ -23,8 +23,6 @@ from tensorflow.python.framework import ops as _ops
from privacy.bolton.loss import StrongConvexMixin
|
||||
from privacy.bolton.optimizer import Bolton
|
||||
|
||||
_accepted_distributions = ['laplace']
|
||||
|
||||
|
||||
class BoltonModel(Model):
|
||||
"""
|
||||
|
@@ -41,41 +39,28 @@ class BoltonModel(Model):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
n_classes,
|
||||
# noise_distribution='laplace',
|
||||
n_outputs,
|
||||
seed=1,
|
||||
dtype=tf.float32
|
||||
):
|
||||
""" private constructor.
|
||||
|
||||
Args:
|
||||
n_classes: number of output classes to predict.
|
||||
epsilon: level of privacy guarantee
|
||||
noise_distribution: distribution to pull weight perturbations from
|
||||
weights_initializer: initializer for weights
|
||||
n_outputs: number of output classes to predict.
|
||||
seed: random seed to use
|
||||
dtype: data type to use for tensors
|
||||
"""
|
||||
|
||||
# if noise_distribution not in _accepted_distributions:
|
||||
# raise ValueError('Detected noise distribution: {0} not one of: {1} valid'
|
||||
# 'distributions'.format(noise_distribution,
|
||||
# _accepted_distributions))
|
||||
# if epsilon <= 0:
|
||||
# raise ValueError('Detected epsilon: {0}. '
|
||||
# 'Valid range is 0 < epsilon <inf'.format(epsilon))
|
||||
# self.epsilon = epsilon
|
||||
super(BoltonModel, self).__init__(name='bolton', dynamic=False)
|
||||
self.n_classes = n_classes
|
||||
self.force = False
|
||||
# self.noise_distribution = noise_distribution
|
||||
if n_outputs <= 0:
|
||||
raise ValueError('n_outputs = {0} is not valid. Must be > 0.'.format(
|
||||
n_outputs
|
||||
))
|
||||
self.n_outputs = n_outputs
|
||||
self.seed = seed
|
||||
self.__in_fit = False
|
||||
self._layers_instantiated = False
|
||||
# self._callback = MyCustomCallback()
|
||||
self._dtype = dtype
|
||||
|
||||
def call(self, inputs):
|
||||
def call(self, inputs, training=False): # pylint: disable=arguments-differ
|
||||
"""Forward pass of network
|
||||
|
||||
Args:
|
||||
|
@@ -87,37 +72,30 @@ class BoltonModel(Model):
|
|||
return self.output_layer(inputs)
|
||||
|
||||
def compile(self,
|
||||
optimizer='SGD',
|
||||
loss=None,
|
||||
optimizer,
|
||||
loss,
|
||||
metrics=None,
|
||||
loss_weights=None,
|
||||
sample_weight_mode=None,
|
||||
weighted_metrics=None,
|
||||
target_tensors=None,
|
||||
distribute=None,
|
||||
**kwargs):
|
||||
kernel_initializer=tf.initializers.GlorotUniform,
|
||||
**kwargs): # pylint: disable=arguments-differ
|
||||
"""See super class. Default optimizer used in Bolton method is SGD.
|
||||
|
||||
"""
|
||||
for key, val in StrongConvexMixin.__dict__.items():
|
||||
if callable(val) and getattr(loss, key, None) is None:
|
||||
raise ValueError("Please ensure you are passing a valid StrongConvex "
|
||||
"loss that has all the required methods "
|
||||
"implemented. "
|
||||
"Required method: {0} not found".format(key))
|
||||
if not isinstance(loss, StrongConvexMixin):
|
||||
raise ValueError("loss function must be a Strongly Convex and therefore "
|
||||
"extend the StrongConvexMixin.")
|
||||
if not self._layers_instantiated: # compile may be called multiple times
|
||||
kernel_intiializer = kwargs.get('kernel_initializer',
|
||||
tf.initializers.GlorotUniform)
|
||||
# for instance, if the input/outputs are not defined until fit.
|
||||
self.output_layer = tf.keras.layers.Dense(
|
||||
self.n_classes,
|
||||
self.n_outputs,
|
||||
kernel_regularizer=loss.kernel_regularizer(),
|
||||
kernel_initializer=kernel_intiializer(),
|
||||
kernel_initializer=kernel_initializer(),
|
||||
)
|
||||
# if we don't do regularization here, we require the user to
|
||||
# re-instantiate the model each time they want to change the penalty
|
||||
# weighting
|
||||
self._layers_instantiated = True
|
||||
self.output_layer.kernel_regularizer.l2 = loss.reg_lambda
|
||||
if not isinstance(optimizer, Bolton):
|
||||
optimizer = optimizers.get(optimizer)
|
||||
optimizer = Bolton(optimizer, loss)
|
||||
|
@@ -133,69 +111,16 @@ class BoltonModel(Model):
|
|||
**kwargs
|
||||
)
|
||||
|
||||
# def _post_fit(self, x, n_samples):
|
||||
# """Implements 1-time weight changes needed for Bolton method.
|
||||
# In this case, specifically implements the noise addition
|
||||
# assuming a strongly convex function.
|
||||
#
|
||||
# Args:
|
||||
# x: inputs
|
||||
# n_samples: number of samples in the inputs. In case the number
|
||||
# cannot be readily determined by inspecting x.
|
||||
#
|
||||
# Returns:
|
||||
#
|
||||
# """
|
||||
# data_size = None
|
||||
# if n_samples is not None:
|
||||
# data_size = n_samples
|
||||
# elif hasattr(x, 'shape'):
|
||||
# data_size = x.shape[0]
|
||||
# elif hasattr(x, "__len__"):
|
||||
# data_size = len(x)
|
||||
# elif data_size is None:
|
||||
# if n_samples is None:
|
||||
# raise ValueError("Unable to detect the number of training "
|
||||
# "samples and n_smaples was None. "
|
||||
# "either pass a dataset with a .shape or "
|
||||
# "__len__ attribute or explicitly pass the "
|
||||
# "number of samples as n_smaples.")
|
||||
# for layer in self.layers:
|
||||
# # layer.kernel = layer.kernel + self._get_noise(
|
||||
# # data_size
|
||||
# # )
|
||||
# input_dim = layer.kernel.numpy().shape[0]
|
||||
# layer.kernel = layer.kernel + self.optimizer.get_noise(
|
||||
# self.loss,
|
||||
# data_size,
|
||||
# input_dim,
|
||||
# self.n_classes,
|
||||
# self.class_weight
|
||||
# )
|
||||
|
||||
def fit(self,
|
||||
x=None,
|
||||
y=None,
|
||||
batch_size=None,
|
||||
epochs=1,
|
||||
verbose=1,
|
||||
callbacks=None,
|
||||
validation_split=0.0,
|
||||
validation_data=None,
|
||||
shuffle=True,
|
||||
class_weight=None,
|
||||
sample_weight=None,
|
||||
initial_epoch=0,
|
||||
steps_per_epoch=None,
|
||||
validation_steps=None,
|
||||
validation_freq=1,
|
||||
max_queue_size=10,
|
||||
workers=1,
|
||||
use_multiprocessing=False,
|
||||
n_samples=None,
|
||||
epsilon=2,
|
||||
noise_distribution='laplace',
|
||||
**kwargs):
|
||||
steps_per_epoch=None,
|
||||
**kwargs): # pylint: disable=arguments-differ
|
||||
"""Reroutes to super fit with additional Bolton delta-epsilon privacy
|
||||
requirements implemented. Note, inputs must be normalized s.t. ||x|| < 1
|
||||
Requirements are as follows:
|
||||
|
@@ -207,92 +132,101 @@ class BoltonModel(Model):
|
|||
|
||||
Args:
|
||||
n_samples: the number of individual samples in x.
|
||||
epsilon: privacy parameter, which trades off between utility and privacy.
|
||||
See the bolton paper for more description.
|
||||
noise_distribution: the distribution to pull noise from.
|
||||
class_weight: the class weights to be used. Can be a scalar or 1D tensor
|
||||
whose dim == n_classes.
|
||||
|
||||
Returns:
|
||||
See the super method for descriptions on the rest of the arguments.
|
||||
|
||||
"""
|
||||
self.__in_fit = True
|
||||
# cb = [self.optimizer.callbacks]
|
||||
# if callbacks is not None:
|
||||
# cb.extend(callbacks)
|
||||
# callbacks = cb
|
||||
if class_weight is None:
|
||||
class_weight = self.calculate_class_weights(class_weight)
|
||||
# self.class_weight = class_weight
|
||||
if n_samples is not None:
|
||||
data_size = n_samples
|
||||
elif hasattr(x, 'shape'):
|
||||
data_size = x.shape[0]
|
||||
elif hasattr(x, "__len__"):
|
||||
data_size = len(x)
|
||||
else:
|
||||
data_size = None
|
||||
batch_size_ = self._validate_or_infer_batch_size(batch_size,
|
||||
steps_per_epoch,
|
||||
x
|
||||
)
|
||||
# inferring batch_size to be passed to optimizer. batch_size must remain its
|
||||
# initial value when passed to super().fit()
|
||||
if batch_size_ is None:
|
||||
raise ValueError('batch_size: {0} is an '
|
||||
'invalid value'.format(batch_size_))
|
||||
with self.optimizer(noise_distribution,
|
||||
epsilon,
|
||||
self.layers,
|
||||
class_weight,
|
||||
n_samples,
|
||||
self.n_classes,
|
||||
) as optim:
|
||||
data_size,
|
||||
self.n_outputs,
|
||||
batch_size_,
|
||||
) as _:
|
||||
out = super(BoltonModel, self).fit(x=x,
|
||||
y=y,
|
||||
batch_size=batch_size,
|
||||
epochs=epochs,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
validation_split=validation_split,
|
||||
validation_data=validation_data,
|
||||
shuffle=shuffle,
|
||||
class_weight=class_weight,
|
||||
sample_weight=sample_weight,
|
||||
initial_epoch=initial_epoch,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
validation_steps=validation_steps,
|
||||
validation_freq=validation_freq,
|
||||
max_queue_size=max_queue_size,
|
||||
workers=workers,
|
||||
use_multiprocessing=use_multiprocessing,
|
||||
**kwargs
|
||||
)
|
||||
return out
|
||||
|
||||
def fit_generator(self,
|
||||
generator,
|
||||
steps_per_epoch=None,
|
||||
epochs=1,
|
||||
verbose=1,
|
||||
callbacks=None,
|
||||
validation_data=None,
|
||||
validation_steps=None,
|
||||
validation_freq=1,
|
||||
class_weight=None,
|
||||
max_queue_size=10,
|
||||
workers=1,
|
||||
use_multiprocessing=False,
|
||||
shuffle=True,
|
||||
initial_epoch=0,
|
||||
n_samples=None
|
||||
):
|
||||
noise_distribution='laplace',
|
||||
epsilon=2,
|
||||
n_samples=None,
|
||||
steps_per_epoch=None,
|
||||
**kwargs
|
||||
): # pylint: disable=arguments-differ
|
||||
"""
|
||||
This method is the same as fit except for when the passed dataset
|
||||
is a generator. See super method and fit for more details.
|
||||
Args:
|
||||
n_samples: number of individual samples in x
|
||||
noise_distribution: the distribution to get noise from.
|
||||
epsilon: privacy parameter, which trades off utility and privacy. See
|
||||
Bolton paper for more description.
|
||||
class_weight: the class weights to be used. Can be a scalar or 1D tensor
|
||||
whose dim == n_classes.
|
||||
|
||||
See the super method for descriptions on the rest of the arguments.
|
||||
"""
|
||||
if class_weight is None:
|
||||
class_weight = self.calculate_class_weights(class_weight)
|
||||
self.class_weight = class_weight
|
||||
out = super(BoltonModel, self).fit_generator(
|
||||
generator,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
epochs=epochs,
|
||||
verbose=verbose,
|
||||
callbacks=callbacks,
|
||||
validation_data=validation_data,
|
||||
validation_steps=validation_steps,
|
||||
validation_freq=validation_freq,
|
||||
class_weight=class_weight,
|
||||
max_queue_size=max_queue_size,
|
||||
workers=workers,
|
||||
use_multiprocessing=use_multiprocessing,
|
||||
shuffle=shuffle,
|
||||
initial_epoch=initial_epoch
|
||||
)
|
||||
if not self.__in_fit:
|
||||
self._post_fit(generator, n_samples)
|
||||
if n_samples is not None:
|
||||
data_size = n_samples
|
||||
elif hasattr(generator, 'shape'):
|
||||
data_size = generator.shape[0]
|
||||
elif hasattr(generator, "__len__"):
|
||||
data_size = len(generator)
|
||||
else:
|
||||
data_size = None
|
||||
batch_size = self._validate_or_infer_batch_size(None,
|
||||
steps_per_epoch,
|
||||
generator
|
||||
)
|
||||
with self.optimizer(noise_distribution,
|
||||
epsilon,
|
||||
self.layers,
|
||||
class_weight,
|
||||
data_size,
|
||||
self.n_outputs,
|
||||
batch_size
|
||||
) as _:
|
||||
out = super(BoltonModel, self).fit_generator(
|
||||
generator,
|
||||
class_weight=class_weight,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
**kwargs
|
||||
)
|
||||
return out
|
||||
|
||||
def calculate_class_weights(self,
|
||||
|
@@ -336,7 +270,7 @@ class BoltonModel(Model):
|
|||
"class_weights=%s" % class_weights)
|
||||
elif class_weights is not None:
|
||||
if num_classes is None:
|
||||
raise ValueError("You must pass a value for num_classes if"
|
||||
raise ValueError("You must pass a value for num_classes if "
|
||||
"creating an array of class_weights")
|
||||
# performing class weight calculation
|
||||
if class_weights is None:
|
||||
|
@@ -357,195 +291,9 @@ class BoltonModel(Model):
|
|||
"1D array".format(class_weights.shape))
|
||||
if class_weights.shape[0] != num_classes:
|
||||
raise ValueError(
|
||||
"Detected array length: {0} instead of: {1}".format(
|
||||
class_weights.shape[0],
|
||||
num_classes
|
||||
)
|
||||
"Detected array length: {0} instead of: {1}".format(
|
||||
class_weights.shape[0],
|
||||
num_classes
|
||||
)
|
||||
)
|
||||
return class_weights
|
||||
|
||||
# def _project_weights_to_r(self, r, force=False):
|
||||
# """helper method to normalize the weights to the R-ball.
|
||||
#
|
||||
# Args:
|
||||
# r: radius of "R-Ball". Scalar to normalize to.
|
||||
# force: True to normalize regardless of previous weight values.
|
||||
# False to check if weights > R-ball and only normalize then.
|
||||
#
|
||||
# Returns:
|
||||
#
|
||||
# """
|
||||
# for layer in self.layers:
|
||||
# weight_norm = tf.norm(layer.kernel, axis=0)
|
||||
# if force:
|
||||
# layer.kernel = layer.kernel / (weight_norm / r)
|
||||
# elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self._dtype)) > 0:
|
||||
# layer.kernel = layer.kernel / (weight_norm / r)
|
||||
|
||||
# def _get_noise(self, distribution, data_size):
|
||||
# """Sample noise to be added to weights for privacy guarantee
|
||||
#
|
||||
# Args:
|
||||
# distribution: the distribution type to pull noise from
|
||||
# data_size: the number of samples
|
||||
#
|
||||
# Returns: noise in shape of layer's weights to be added to the weights.
|
||||
#
|
||||
# """
|
||||
# distribution = distribution.lower()
|
||||
# input_dim = self.layers[0].kernel.numpy().shape[0]
|
||||
# loss = self.loss
|
||||
# if distribution == _accepted_distributions[0]: # laplace
|
||||
# per_class_epsilon = self.epsilon / (self.n_classes)
|
||||
# l2_sensitivity = (2 *
|
||||
# loss.lipchitz_constant(self.class_weight)) / \
|
||||
# (loss.gamma() * data_size)
|
||||
# unit_vector = tf.random.normal(shape=(input_dim, self.n_classes),
|
||||
# mean=0,
|
||||
# seed=1,
|
||||
# stddev=1.0,
|
||||
# dtype=self._dtype)
|
||||
# unit_vector = unit_vector / tf.math.sqrt(
|
||||
# tf.reduce_sum(tf.math.square(unit_vector), axis=0)
|
||||
# )
|
||||
#
|
||||
# beta = l2_sensitivity / per_class_epsilon
|
||||
# alpha = input_dim # input_dim
|
||||
# gamma = tf.random.gamma([self.n_classes],
|
||||
# alpha,
|
||||
# beta=1 / beta,
|
||||
# seed=1,
|
||||
# dtype=self._dtype
|
||||
# )
|
||||
# return unit_vector * gamma
|
||||
# raise NotImplementedError('Noise distribution: {0} is not '
|
||||
# 'a valid distribution'.format(distribution))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import tensorflow as tf
|
||||
|
||||
import os
|
||||
import time
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
_URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz'
|
||||
|
||||
path_to_zip = tf.keras.utils.get_file('facades.tar.gz',
|
||||
origin=_URL,
|
||||
extract=True)
|
||||
|
||||
PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')
|
||||
BUFFER_SIZE = 400
|
||||
BATCH_SIZE = 1
|
||||
IMG_WIDTH = 256
|
||||
IMG_HEIGHT = 256
|
||||
|
||||
|
||||
def load(image_file):
|
||||
image = tf.io.read_file(image_file)
|
||||
image = tf.image.decode_jpeg(image)
|
||||
|
||||
w = tf.shape(image)[1]
|
||||
|
||||
w = w // 2
|
||||
real_image = image[:, :w, :]
|
||||
input_image = image[:, w:, :]
|
||||
|
||||
input_image = tf.cast(input_image, tf.float32)
|
||||
real_image = tf.cast(real_image, tf.float32)
|
||||
|
||||
return input_image, real_image
|
||||
|
||||
|
||||
inp, re = load(PATH + 'train/100.jpg')
|
||||
# casting to int for matplotlib to show the image
|
||||
plt.figure()
|
||||
plt.imshow(inp / 255.0)
|
||||
plt.figure()
|
||||
plt.imshow(re / 255.0)
|
||||
|
||||
|
||||
def resize(input_image, real_image, height, width):
|
||||
input_image = tf.image.resize(input_image, [height, width],
|
||||
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
|
||||
real_image = tf.image.resize(real_image, [height, width],
|
||||
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
|
||||
|
||||
return input_image, real_image
|
||||
|
||||
|
||||
def random_crop(input_image, real_image):
|
||||
stacked_image = tf.stack([input_image, real_image], axis=0)
|
||||
cropped_image = tf.image.random_crop(
|
||||
stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
|
||||
|
||||
return cropped_image[0], cropped_image[1]
|
||||
|
||||
|
||||
def normalize(input_image, real_image):
|
||||
input_image = (input_image / 127.5) - 1
|
||||
real_image = (real_image / 127.5) - 1
|
||||
|
||||
return input_image, real_image
|
||||
|
||||
|
||||
@tf.function()
|
||||
def random_jitter(input_image, real_image):
|
||||
# resizing to 286 x 286 x 3
|
||||
input_image, real_image = resize(input_image, real_image, 286, 286)
|
||||
|
||||
# randomly cropping to 256 x 256 x 3
|
||||
input_image, real_image = random_crop(input_image, real_image)
|
||||
|
||||
if tf.random.uniform(()) > 0.5:
|
||||
# random mirroring
|
||||
input_image = tf.image.flip_left_right(input_image)
|
||||
real_image = tf.image.flip_left_right(real_image)
|
||||
|
||||
return input_image, real_image
|
||||
|
||||
|
||||
def load_image_train(image_file):
|
||||
input_image, real_image = load(image_file)
|
||||
input_image, real_image = random_jitter(input_image, real_image)
|
||||
input_image, real_image = normalize(input_image, real_image)
|
||||
|
||||
return input_image, real_image
|
||||
|
||||
|
||||
def load_image_test(image_file):
|
||||
input_image, real_image = load(image_file)
|
||||
input_image, real_image = resize(input_image, real_image,
|
||||
IMG_HEIGHT, IMG_WIDTH)
|
||||
input_image, real_image = normalize(input_image, real_image)
|
||||
|
||||
return input_image, real_image
|
||||
|
||||
|
||||
train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg')
|
||||
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
|
||||
train_dataset = train_dataset.map(load_image_train,
|
||||
num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||
train_dataset = train_dataset.batch(1)
|
||||
# steps_per_epoch = training_utils.infer_steps_for_dataset(
|
||||
# train_dataset, None, epochs=1, steps_name='steps')
|
||||
|
||||
# for batch in train_dataset:
|
||||
# print(batch[1].shape)
|
||||
test_dataset = tf.data.Dataset.list_files(PATH + 'test/*.jpg')
|
||||
# shuffling so that for every epoch a different image is generated
|
||||
# to predict and display the progress of our model.
|
||||
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
|
||||
test_dataset = test_dataset.map(load_image_test)
|
||||
test_dataset = test_dataset.batch(1)
|
||||
|
||||
be = BoltonModel(3, 2)
|
||||
from tensorflow.python.keras.optimizer_v2 import adam
|
||||
from privacy.bolton import loss
|
||||
|
||||
test = adam.Adam()
|
||||
l = loss.StrongConvexBinaryCrossentropy(1, 2, 1)
|
||||
be.compile(test, l)
|
||||
print("Eager exeuction: {0}".format(tf.executing_eagerly()))
|
||||
be.fit(train_dataset, verbose=0, steps_per_epoch=1, n_samples=1)
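# The fit docstring earlier in this file requires inputs normalized so that
# ||x|| < 1 per sample (note the pix2pix demo above only scales pixels to
# [-1, 1], not to the unit ball). A minimal illustrative sketch, not part of
# this commit, of one way to enforce that before calling fit:
def scale_rows_into_unit_ball(x, margin=1e-6):
    """Scale the batch so every row's L2 norm is strictly below 1."""
    x = tf.cast(x, tf.float32)
    max_norm = tf.reduce_max(tf.norm(x, axis=1))
    return x / (max_norm + margin)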
@@ -19,25 +19,22 @@ from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
|
||||
from tensorflow.python.keras import losses
|
||||
from tensorflow.python.framework import ops as _ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from privacy.bolton import model
|
||||
from privacy.bolton.loss import StrongConvexMixin
|
||||
from absl.testing import parameterized
|
||||
from absl.testing import absltest
|
||||
from tensorflow.python.keras.regularizers import L1L2
|
||||
|
||||
from absl.testing import parameterized
|
||||
from privacy.bolton import model
|
||||
from privacy.bolton.optimizer import Bolton
|
||||
from privacy.bolton.loss import StrongConvexMixin
|
||||
|
||||
class TestLoss(losses.Loss, StrongConvexMixin):
|
||||
"""Test loss function for testing Bolton model"""
|
||||
def __init__(self, reg_lambda, C, radius_constant, name='test'):
|
||||
super(TestLoss, self).__init__(name=name)
|
||||
self.reg_lambda = reg_lambda
|
||||
self.C = C
|
||||
self.C = C # pylint: disable=invalid-name
|
||||
self.radius_constant = radius_constant
|
||||
|
||||
def radius(self):
|
||||
|
@@ -78,13 +75,17 @@ class TestLoss(losses.Loss, StrongConvexMixin):
|
|||
"""
|
||||
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
|
||||
|
||||
def call(self, val0, val1):
|
||||
def call(self, y_true, y_pred):
|
||||
"""Loss function that is minimized at the mean of the input points."""
|
||||
return 0.5 * tf.reduce_sum(tf.math.squared_difference(val0, val1), axis=1)
|
||||
return 0.5 * tf.reduce_sum(
|
||||
tf.math.squared_difference(y_true, y_pred),
|
||||
axis=1
|
||||
)
|
||||
|
||||
def max_class_weight(self, class_weight):
|
||||
if class_weight is None:
|
||||
return 1
|
||||
raise ValueError('')
|
||||
|
||||
def kernel_regularizer(self):
|
||||
return L1L2(l2=self.reg_lambda)
|
||||
|
@@ -116,125 +117,91 @@ class InitTests(keras_parameterized.TestCase):
|
|||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'normal',
|
||||
'n_classes': 1,
|
||||
'epsilon': 1,
|
||||
'noise_distribution': 'laplace',
|
||||
'seed': 1
|
||||
'n_outputs': 1,
|
||||
},
|
||||
{'testcase_name': 'extreme range',
|
||||
'n_classes': 5,
|
||||
'epsilon': 0.1,
|
||||
'noise_distribution': 'laplace',
|
||||
'seed': 10
|
||||
},
|
||||
{'testcase_name': 'extreme range2',
|
||||
'n_classes': 50,
|
||||
'epsilon': 10,
|
||||
'noise_distribution': 'laplace',
|
||||
'seed': 100
|
||||
{'testcase_name': 'many outputs',
|
||||
'n_outputs': 100,
|
||||
},
|
||||
])
|
||||
def test_init_params(
|
||||
self, n_classes, epsilon, noise_distribution, seed):
|
||||
def test_init_params(self, n_outputs):
|
||||
"""test initialization of BoltonModel
|
||||
|
||||
Args:
|
||||
n_outputs: number of output neurons
|
||||
"""
|
||||
# test valid domains for each variable
|
||||
clf = model.Bolton(n_classes,
|
||||
epsilon,
|
||||
noise_distribution,
|
||||
seed
|
||||
)
|
||||
self.assertIsInstance(clf, model.Bolton)
|
||||
clf = model.BoltonModel(n_outputs)
|
||||
self.assertIsInstance(clf, model.BoltonModel)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'invalid noise',
|
||||
'n_classes': 1,
|
||||
'epsilon': 1,
|
||||
'noise_distribution': 'not_valid',
|
||||
'weights_initializer': tf.initializers.GlorotUniform(),
|
||||
},
|
||||
{'testcase_name': 'invalid epsilon',
|
||||
'n_classes': 1,
|
||||
'epsilon': -1,
|
||||
'noise_distribution': 'laplace',
|
||||
'weights_initializer': tf.initializers.GlorotUniform(),
|
||||
{'testcase_name': 'invalid n_outputs',
|
||||
'n_outputs': -1,
|
||||
},
|
||||
])
|
||||
def test_bad_init_params(
|
||||
self,
|
||||
n_classes,
|
||||
epsilon,
|
||||
noise_distribution,
|
||||
weights_initializer):
|
||||
def test_bad_init_params(self, n_outputs):
|
||||
"""test bad initializations of BoltonModel that should raise errors
|
||||
|
||||
Args:
|
||||
n_outputs: number of output neurons
|
||||
"""
|
||||
# test invalid domains for each variable, especially noise
|
||||
seed = 1
|
||||
with self.assertRaises(ValueError):
|
||||
clf = model.Bolton(n_classes,
|
||||
epsilon,
|
||||
noise_distribution,
|
||||
weights_initializer,
|
||||
seed
|
||||
)
|
||||
model.BoltonModel(n_outputs)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'string compile',
|
||||
'n_classes': 1,
|
||||
'n_outputs': 1,
|
||||
'loss': TestLoss(1, 1, 1),
|
||||
'optimizer': 'adam',
|
||||
'weights_initializer': tf.initializers.GlorotUniform(),
|
||||
},
|
||||
{'testcase_name': 'test compile',
|
||||
'n_classes': 100,
|
||||
'n_outputs': 100,
|
||||
'loss': TestLoss(1, 1, 1),
|
||||
'optimizer': TestOptimizer(),
|
||||
'weights_initializer': tf.initializers.GlorotUniform(),
|
||||
},
|
||||
{'testcase_name': 'invalid weights initializer',
|
||||
'n_classes': 1,
|
||||
'loss': TestLoss(1, 1, 1),
|
||||
'optimizer': TestOptimizer(),
|
||||
'weights_initializer': 'not_valid',
|
||||
},
|
||||
])
|
||||
def test_compile(self, n_classes, loss, optimizer, weights_initializer):
|
||||
def test_compile(self, n_outputs, loss, optimizer):
|
||||
"""test compilation of BoltonModel
|
||||
|
||||
Args:
|
||||
n_outputs: number of output neurons
|
||||
loss: instantiated TestLoss instance
|
||||
optimizer: instantiated TestOptimizer instance
|
||||
"""
|
||||
# test compilation of valid tf.optimizer and tf.loss
|
||||
epsilon = 1
|
||||
noise_distribution = 'laplace'
|
||||
with self.cached_session():
|
||||
clf = model.Bolton(n_classes,
|
||||
epsilon,
|
||||
noise_distribution,
|
||||
weights_initializer
|
||||
)
|
||||
clf = model.BoltonModel(n_outputs)
|
||||
clf.compile(optimizer, loss)
|
||||
self.assertEqual(clf.loss, loss)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'Not strong loss',
|
||||
'n_classes': 1,
|
||||
'n_outputs': 1,
|
||||
'loss': losses.BinaryCrossentropy(),
|
||||
'optimizer': 'adam',
|
||||
},
|
||||
{'testcase_name': 'Not valid optimizer',
|
||||
'n_classes': 1,
|
||||
'n_outputs': 1,
|
||||
'loss': TestLoss(1, 1, 1),
|
||||
'optimizer': 'ada',
|
||||
}
|
||||
])
|
||||
def test_bad_compile(self, n_classes, loss, optimizer):
|
||||
def test_bad_compile(self, n_outputs, loss, optimizer):
|
||||
"""test bad compilations of BoltonModel that should raise errors
|
||||
|
||||
Args:
|
||||
n_outputs: number of output neurons
|
||||
loss: instantiated TestLoss instance
|
||||
optimizer: instantiated TestOptimizer instance
|
||||
"""
|
||||
# test compilation of an invalid tf.optimizer and a non-instantiated loss.
|
||||
epsilon = 1
|
||||
noise_distribution = 'laplace'
|
||||
weights_initializer = tf.initializers.GlorotUniform()
|
||||
with self.cached_session():
|
||||
with self.assertRaises((ValueError, AttributeError)):
|
||||
clf = model.Bolton(n_classes,
|
||||
epsilon,
|
||||
noise_distribution,
|
||||
weights_initializer
|
||||
)
|
||||
clf = model.BoltonModel(n_outputs)
|
||||
clf.compile(optimizer, loss)
|
||||
|
||||
|
||||
def _cat_dataset(n_samples, input_dim, n_classes, t='train', generator=False):
|
||||
def _cat_dataset(n_samples, input_dim, n_classes, generator=False):
|
||||
"""
|
||||
Creates a categorically encoded dataset (y is categorical).
|
||||
returns the specified dataset either as a static array or as a generator.
|
||||
|
@@ -245,10 +212,9 @@ def _cat_dataset(n_samples, input_dim, n_classes, t='train', generator=False):
|
|||
n_samples: number of rows
|
||||
input_dim: input dimensionality
|
||||
n_classes: output dimensionality
|
||||
t: one of 'train', 'val', 'test'
|
||||
generator: False for array, True for generator
|
||||
Returns:
|
||||
X as (n_samples, input_dim), Y as (n_samples, n_classes)
|
||||
X as (n_samples, input_dim), Y as (n_samples, n_outputs)
|
||||
"""
|
||||
x_stack = []
|
||||
y_stack = []
|
||||
|
@@ -269,25 +235,39 @@ def _cat_dataset(n_samples, input_dim, n_classes, t='train', generator=False):
|
|||
|
||||
def _do_fit(n_samples,
|
||||
input_dim,
|
||||
n_classes,
|
||||
n_outputs,
|
||||
epsilon,
|
||||
generator,
|
||||
batch_size,
|
||||
reset_n_samples,
|
||||
optimizer,
|
||||
loss,
|
||||
callbacks,
|
||||
distribution='laplace'):
|
||||
clf = model.Bolton(n_classes,
|
||||
epsilon,
|
||||
distribution
|
||||
)
|
||||
"""Helper to instantiate necessary components for fitting and perform a model
|
||||
fit.
|
||||
|
||||
Args:
|
||||
n_samples: number of samples in dataset
|
||||
input_dim: the sample dimensionality
|
||||
n_outputs: number of output neurons
|
||||
epsilon: privacy parameter
|
||||
generator: True to create a generator, False to use an iterator
|
||||
batch_size: batch_size to use
|
||||
reset_n_samples: True to set n_samples to None prior to fitting.
|
||||
False does nothing
|
||||
optimizer: instance of TestOptimizer
|
||||
loss: instance of TestLoss
|
||||
distribution: distribution to get noise from.
|
||||
|
||||
Returns: BoltonModel instance
|
||||
"""
|
||||
clf = model.BoltonModel(n_outputs)
|
||||
clf.compile(optimizer, loss)
|
||||
if generator:
|
||||
x = _cat_dataset(
|
||||
n_samples,
|
||||
input_dim,
|
||||
n_classes,
|
||||
n_outputs,
|
||||
generator=generator
|
||||
)
|
||||
y = None
|
||||
|
@@ -295,25 +275,20 @@ def _do_fit(n_samples,
|
|||
x = x.shuffle(n_samples//2)
|
||||
batch_size = None
|
||||
else:
|
||||
x, y = _cat_dataset(n_samples, input_dim, n_classes, generator=generator)
|
||||
x, y = _cat_dataset(n_samples, input_dim, n_outputs, generator=generator)
|
||||
if reset_n_samples:
|
||||
n_samples = None
|
||||
|
||||
if callbacks is not None:
|
||||
callbacks = [callbacks]
|
||||
clf.fit(x,
|
||||
y,
|
||||
batch_size=batch_size,
|
||||
n_samples=n_samples,
|
||||
callbacks=callbacks
|
||||
noise_distribution=distribution,
|
||||
epsilon=epsilon
|
||||
)
|
||||
return clf
|
||||
|
||||
|
||||
class TestCallback(tf.keras.callbacks.Callback):
|
||||
pass
|
||||
|
||||
|
||||
class FitTests(keras_parameterized.TestCase):
|
||||
"""Test cases for keras model fitting"""
|
||||
|
||||
|
@@ -322,27 +297,29 @@ class FitTests(keras_parameterized.TestCase):
|
|||
{'testcase_name': 'iterator fit',
|
||||
'generator': False,
|
||||
'reset_n_samples': True,
|
||||
'callbacks': None
|
||||
},
|
||||
{'testcase_name': 'iterator fit no samples',
|
||||
'generator': False,
|
||||
'reset_n_samples': True,
|
||||
'callbacks': None
|
||||
},
|
||||
{'testcase_name': 'generator fit',
|
||||
'generator': True,
|
||||
'reset_n_samples': False,
|
||||
'callbacks': None
|
||||
},
|
||||
{'testcase_name': 'with callbacks',
|
||||
'generator': True,
|
||||
'reset_n_samples': False,
|
||||
'callbacks': TestCallback()
|
||||
},
|
||||
])
|
||||
def test_fit(self, generator, reset_n_samples, callbacks):
|
||||
def test_fit(self, generator, reset_n_samples):
|
||||
"""Tests fitting of BoltonModel
|
||||
|
||||
Args:
|
||||
generator: True for generator test, False for iterator test.
|
||||
reset_n_samples: True to reset the n_samples to None, False does nothing
|
||||
"""
|
||||
loss = TestLoss(1, 1, 1)
|
||||
optimizer = TestOptimizer()
|
||||
optimizer = Bolton(TestOptimizer(), loss)
|
||||
n_classes = 2
|
||||
input_dim = 5
|
||||
epsilon = 1
|
||||
|
@@ -358,28 +335,27 @@ class FitTests(keras_parameterized.TestCase):
|
|||
reset_n_samples,
|
||||
optimizer,
|
||||
loss,
|
||||
callbacks
|
||||
)
|
||||
self.assertEqual(hasattr(clf, 'layers'), True)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'generator fit',
|
||||
'generator': True,
|
||||
'reset_n_samples': False,
|
||||
'callbacks': None
|
||||
},
|
||||
])
|
||||
def test_fit_gen(self, generator, reset_n_samples, callbacks):
|
||||
def test_fit_gen(self, generator):
|
||||
"""Tests the fit_generator method of BoltonModel
|
||||
|
||||
Args:
|
||||
generator: True to test with a generator dataset
|
||||
"""
|
||||
loss = TestLoss(1, 1, 1)
|
||||
optimizer = TestOptimizer()
|
||||
n_classes = 2
|
||||
input_dim = 5
|
||||
epsilon = 1
|
||||
batch_size = 1
|
||||
n_samples = 10
|
||||
clf = model.Bolton(n_classes,
|
||||
epsilon
|
||||
)
|
||||
clf = model.BoltonModel(n_classes)
|
||||
clf.compile(optimizer, loss)
|
||||
x = _cat_dataset(
|
||||
n_samples,
|
||||
|
@@ -405,6 +381,14 @@ class FitTests(keras_parameterized.TestCase):
|
|||
},
|
||||
])
|
||||
def test_bad_fit(self, generator, reset_n_samples, distribution):
|
||||
"""Tests fitting with invalid parameters, which should raise an error
|
||||
|
||||
Args:
|
||||
generator: True to test with generator, False is iterator
|
||||
reset_n_samples: True to reset the n_samples param to None prior to
|
||||
passing it to fit
|
||||
distribution: distribution to get noise from.
|
||||
"""
|
||||
with self.assertRaises(ValueError):
|
||||
loss = TestLoss(1, 1, 1)
|
||||
optimizer = TestOptimizer()
|
||||
|
@@ -423,7 +407,6 @@ class FitTests(keras_parameterized.TestCase):
|
|||
reset_n_samples,
|
||||
optimizer,
|
||||
loss,
|
||||
None,
|
||||
distribution
|
||||
)
|
||||
|
||||
|
@@ -450,7 +433,15 @@ class FitTests(keras_parameterized.TestCase):
|
|||
num_classes,
|
||||
result
|
||||
):
|
||||
clf = model.Bolton(1, 1)
|
||||
"""Tests the BOltonModel calculate_class_weights method
|
||||
|
||||
Args:
|
||||
class_weights: the class_weights to use
|
||||
class_counts: count of number of samples for each class
|
||||
num_classes: number of outputs neurons
|
||||
result: expected result
|
||||
"""
|
||||
clf = model.BoltonModel(1, 1)
|
||||
expected = clf.calculate_class_weights(class_weights,
|
||||
class_counts,
|
||||
num_classes
|
||||
|
@@ -508,12 +499,21 @@ class FitTests(keras_parameterized.TestCase):
|
|||
num_classes,
|
||||
err_msg
|
||||
):
|
||||
clf = model.Bolton(1, 1)
|
||||
with self.assertRaisesRegexp(ValueError, err_msg):
|
||||
expected = clf.calculate_class_weights(class_weights,
|
||||
class_counts,
|
||||
num_classes
|
||||
)
|
||||
"""Tests the BOltonModel calculate_class_weights method with invalid params
|
||||
which should raise the expected errors.
|
||||
|
||||
Args:
|
||||
class_weights: the class_weights to use
|
||||
class_counts: count of number of samples for each class
|
||||
num_classes: number of outputs neurons
|
||||
result: expected result
|
||||
"""
|
||||
clf = model.BoltonModel(1, 1)
|
||||
with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method
|
||||
clf.calculate_class_weights(class_weights,
|
||||
class_counts,
|
||||
num_classes
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
@@ -19,9 +19,74 @@ from __future__ import print_function
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python import ops as _ops
|
||||
from privacy.bolton.loss import StrongConvexMixin
|
||||
|
||||
_accepted_distributions = ['laplace']
|
||||
_accepted_distributions = ['laplace'] # implemented distributions for noising
|
||||
|
||||
|
||||
class GammaBetaDecreasingStep(
|
||||
optimizer_v2.learning_rate_schedule.LearningRateSchedule
|
||||
):
|
||||
"""
|
||||
Learning Rate Scheduler using the minimum of 1/beta and 1/(gamma * step)
|
||||
at each step. A required step for privacy guarantees.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.is_init = False
|
||||
self.beta = None
|
||||
self.gamma = None
|
||||
|
||||
def __call__(self, step):
|
||||
"""
|
||||
returns the learning rate
|
||||
Args:
|
||||
step: the current iteration number
|
||||
Returns:
|
||||
decayed learning rate to minimum of 1/beta and 1/(gamma * step) as per
|
||||
the Bolton privacy requirements.
|
||||
"""
|
||||
if not self.is_init:
|
||||
raise AttributeError('Please initialize the {0} Learning Rate Scheduler. '
|
||||
'This is performed automatically by using the '
|
||||
'{1} as a context manager, '
|
||||
'as desired'.format(self.__class__.__name__,
|
||||
Bolton.__name__
|
||||
)
|
||||
)
|
||||
dtype = self.beta.dtype
|
||||
one = tf.constant(1, dtype)
|
||||
return tf.math.minimum(tf.math.reduce_min(one/self.beta),
|
||||
one/(self.gamma*math_ops.cast(step, dtype))
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
"""
|
||||
config to setup the learning rate scheduler.
|
||||
"""
|
||||
return {'beta': self.beta, 'gamma': self.gamma}
|
||||
|
||||
def initialize(self, beta, gamma):
|
||||
"""setup the learning rate scheduler with the beta and gamma values provided
|
||||
by the loss function. Meant to be used with .fit as the loss params may
|
||||
depend on values passed to fit.
|
||||
|
||||
Args:
|
||||
beta: Smoothness value. See StrongConvexMixin
|
||||
gamma: Strong Convexity parameter. See StrongConvexMixin.
|
||||
"""
|
||||
self.is_init = True
|
||||
self.beta = beta
|
||||
self.gamma = gamma
|
||||
|
||||
def de_initialize(self):
|
||||
"""De initialize the scheduler after fitting, in case another fit call has
|
||||
different loss parameters.
|
||||
"""
|
||||
self.is_init = False
|
||||
self.beta = None
|
||||
self.gamma = None
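# A scalar mirror of the schedule above, to make the decay concrete
# (illustrative values only, not part of the optimizer module):
def decayed_lr(step, beta=2.0, gamma=0.5):
    """Plain-Python version of GammaBetaDecreasingStep.__call__ for scalars."""
    return min(1.0 / beta, 1.0 / (gamma * step))

# decayed_lr(1) == 0.5, decayed_lr(5) == 0.4, decayed_lr(10) == 0.2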
|
||||
|
||||
|
||||
class Bolton(optimizer_v2.OptimizerV2):
|
||||
|
@@ -31,11 +96,24 @@ class Bolton(optimizer_v2.OptimizerV2):
|
|||
passed, "Bolton" enables the bolton model to control the learning rate
|
||||
based on the strongly convex loss.
|
||||
|
||||
To use the Bolton method, you must:
|
||||
1. instantiate it with an instantiated tf optimizer and StrongConvexLoss.
|
||||
2. use it as a context manager around your .fit method internals.
|
||||
|
||||
This can be accomplished by the following:
|
||||
optimizer = tf.optimizers.SGD()
|
||||
loss = privacy.bolton.losses.StrongConvexBinaryCrossentropy()
|
||||
bolton = Bolton(optimizer, loss)
|
||||
with bolton(*args) as _:
|
||||
model.fit()
|
||||
The args required for the context manager can be found in the __call__
|
||||
method.
|
||||
|
||||
For more details on the strong convexity requirements, see:
|
||||
Bolt-on Differential Privacy for Scalable Stochastic Gradient
|
||||
Descent-based Analytics by Xi Wu et. al.
|
||||
"""
|
||||
def __init__(self,
|
||||
def __init__(self, # pylint: disable=super-init-not-called
|
||||
optimizer: optimizer_v2.OptimizerV2,
|
||||
loss: StrongConvexMixin,
|
||||
dtype=tf.float32,
|
||||
|
@@ -45,10 +123,11 @@ class Bolton(optimizer_v2.OptimizerV2):
|
|||
Args:
|
||||
optimizer: Optimizer_v2 or subclass to be used as the optimizer
|
||||
(wrapped).
|
||||
loss: StrongConvexLoss function that the model is being compiled with.
|
||||
"""
|
||||
|
||||
if not isinstance(loss, StrongConvexMixin):
|
||||
raise ValueError("loss function must be a Strongly Convex and therfore"
|
||||
raise ValueError("loss function must be a Strongly Convex and therefore "
|
||||
"extend the StrongConvexMixin.")
|
||||
self._private_attributes = ['_internal_optimizer',
|
||||
'dtype',
|
||||
|
@@ -58,13 +137,19 @@ class Bolton(optimizer_v2.OptimizerV2):
|
|||
'class_weights',
|
||||
'input_dim',
|
||||
'n_samples',
|
||||
'n_classes',
|
||||
'n_outputs',
|
||||
'layers',
|
||||
'_model'
|
||||
'batch_size',
|
||||
'_is_init'
|
||||
]
|
||||
self._internal_optimizer = optimizer
|
||||
self.learning_rate = GammaBetaDecreasingStep() # use the Bolton Learning
|
||||
# rate scheduler, as required for privacy guarantees. This will still need
|
||||
# to get values from the loss function near the time that .fit is called
|
||||
# on the model (when this optimizer will be called as a context manager)
|
||||
self.dtype = dtype
|
||||
self.loss = loss
|
||||
self._is_init = False
|
||||
|
||||
def get_config(self):
|
||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||
|
@@ -75,49 +160,44 @@ class Bolton(optimizer_v2.OptimizerV2):
|
|||
"""helper method to normalize the weights to the R-ball.
|
||||
|
||||
Args:
|
||||
r: radius of "R-Ball". Scalar to normalize to.
|
||||
force: True to normalize regardless of previous weight values.
|
||||
False to check if weights > R-ball and only normalize then.
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
r = self.loss.radius()
|
||||
radius = self.loss.radius()
|
||||
for layer in self.layers:
|
||||
if tf.executing_eagerly():
|
||||
weight_norm = tf.norm(layer.kernel, axis=0)
|
||||
if force:
|
||||
layer.kernel = layer.kernel / (weight_norm / r)
|
||||
elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self.dtype)) > 0:
|
||||
layer.kernel = layer.kernel / (weight_norm / r)
|
||||
weight_norm = tf.norm(layer.kernel, axis=0)
|
||||
if force:
|
||||
layer.kernel = layer.kernel / (weight_norm / radius)
|
||||
else:
|
||||
weight_norm = tf.norm(layer.kernel, axis=0)
|
||||
if force:
|
||||
layer.kernel = layer.kernel / (weight_norm / r)
|
||||
else:
|
||||
layer.kernel = tf.cond(
|
||||
tf.reduce_sum(tf.cast(weight_norm > r, dtype=self.dtype)) > 0,
|
||||
lambda: layer.kernel / (weight_norm / r),
|
||||
lambda: layer.kernel
|
||||
)
|
||||
layer.kernel = tf.cond(
|
||||
tf.reduce_sum(tf.cast(weight_norm > radius, dtype=self.dtype)) > 0,
|
||||
lambda k=layer.kernel, w=weight_norm, r=radius: k / (w / r), # pylint: disable=cell-var-from-loop
|
||||
lambda k=layer.kernel: k # pylint: disable=cell-var-from-loop
|
||||
)
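# A NumPy sketch of the per-column rescaling performed above (illustrative
# only; the real code operates on layer.kernel and uses tf.cond in graph mode,
# and assumes no zero-norm columns):
import numpy as np

def project_to_r_ball(kernel, r, force=False):
    """Rescale every output column to norm r if force or any column exceeds r."""
    weight_norm = np.linalg.norm(kernel, axis=0)   # one norm per output column
    if force or np.any(weight_norm > r):
        return kernel / (weight_norm / r)
    return kernel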
|
||||
|
||||
def get_noise(self, data_size, input_dim, output_dim, class_weight):
|
||||
def get_noise(self, input_dim, output_dim):
|
||||
"""Sample noise to be added to weights for privacy guarantee
|
||||
|
||||
Args:
|
||||
distribution: the distribution type to pull noise from
|
||||
data_size: the number of samples
|
||||
input_dim: the input dimensionality for the weights
|
||||
output_dim the output dimensionality for the weights
|
||||
|
||||
Returns: noise in shape of layer's weights to be added to the weights.
|
||||
|
||||
"""
|
||||
if not self._is_init:
|
||||
raise Exception('This method must be called from within the optimizer\'s '
|
||||
'context.')
|
||||
loss = self.loss
|
||||
distribution = self.noise_distribution.lower()
|
||||
if distribution == _accepted_distributions[0]: # laplace
|
||||
per_class_epsilon = self.epsilon / (output_dim)
|
||||
l2_sensitivity = (2 *
|
||||
loss.lipchitz_constant(class_weight)) / \
|
||||
(loss.gamma() * data_size)
|
||||
loss.lipchitz_constant(self.class_weights)) / \
|
||||
(loss.gamma() * self.n_samples * self.batch_size)
|
||||
unit_vector = tf.random.normal(shape=(input_dim, output_dim),
|
||||
mean=0,
|
||||
seed=1,
|
||||
|
@@ -139,28 +219,7 @@ class Bolton(optimizer_v2.OptimizerV2):
|
|||
raise NotImplementedError('Noise distribution: {0} is not '
|
||||
'a valid distribution'.format(distribution))
|
||||
|
||||
def limit_learning_rate(self, beta, gamma):
|
||||
"""Implements learning rate limitation that is required by the bolton
|
||||
method for sensitivity bounding of the strongly convex function.
|
||||
Sets the learning rate to the min(1/beta, 1/(gamma*t))
|
||||
|
||||
Args:
|
||||
is_eager: Whether the model is running in eager mode
|
||||
beta: loss function beta-smoothness
|
||||
gamma: loss function gamma-strongly convex
|
||||
|
||||
Returns: None
|
||||
|
||||
"""
|
||||
numerator = tf.constant(1, dtype=self.dtype)
|
||||
t = tf.cast(self._iterations, self.dtype)
|
||||
# will exist on the internal optimizer
|
||||
if numerator / beta < numerator / (gamma * t):
|
||||
self.learning_rate = numerator / beta
|
||||
else:
|
||||
self.learning_rate = numerator / (gamma * t)
|
||||
|
||||
def from_config(self, *args, **kwargs):
|
||||
def from_config(self, *args, **kwargs): # pylint: disable=arguments-differ
|
||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||
"""
|
||||
return self._internal_optimizer.from_config(*args, **kwargs)
|
||||
|
@@ -176,21 +235,19 @@ class Bolton(optimizer_v2.OptimizerV2):
|
|||
from _internal_optimizer.
|
||||
|
||||
"""
|
||||
if name == '_private_attributes':
|
||||
return getattr(self, name)
|
||||
elif name in self._private_attributes:
|
||||
if name == '_private_attributes' or name in self._private_attributes:
|
||||
return getattr(self, name)
|
||||
optim = object.__getattribute__(self, '_internal_optimizer')
|
||||
try:
|
||||
return object.__getattribute__(optim, name)
|
||||
except AttributeError:
|
||||
raise AttributeError("Neither '{0}' nor '{1}' object has attribute '{2}'"
|
||||
"".format(
|
||||
self.__class__.__name__,
|
||||
self._internal_optimizer.__class__.__name__,
|
||||
name
|
||||
)
|
||||
)
|
||||
raise AttributeError(
|
||||
"Neither '{0}' nor '{1}' object has attribute '{2}'"
|
||||
"".format(self.__class__.__name__,
|
||||
self._internal_optimizer.__class__.__name__,
|
||||
name
|
||||
)
|
||||
)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
""" Set attribute to self instance if its the internal optimizer.
|
||||
|
@@ -205,113 +262,110 @@ class Bolton(optimizer_v2.OptimizerV2):
|
|||
"""
|
||||
if key == '_private_attributes':
|
||||
object.__setattr__(self, key, value)
|
||||
elif key in key in self._private_attributes:
|
||||
elif key in self._private_attributes:
|
||||
object.__setattr__(self, key, value)
|
||||
else:
|
||||
setattr(self._internal_optimizer, key, value)
|
||||
|
||||
def _resource_apply_dense(self, *args, **kwargs):
|
||||
def _resource_apply_dense(self, *args, **kwargs): # pylint: disable=arguments-differ
|
||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||
"""
|
||||
return self._internal_optimizer._resource_apply_dense(*args, **kwargs)
|
||||
return self._internal_optimizer._resource_apply_dense(*args, **kwargs) # pylint: disable=protected-access
|
||||
|
||||
def _resource_apply_sparse(self, *args, **kwargs):
|
||||
def _resource_apply_sparse(self, *args, **kwargs): # pylint: disable=arguments-differ
|
||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||
"""
|
||||
return self._internal_optimizer._resource_apply_sparse(*args, **kwargs)
|
||||
return self._internal_optimizer._resource_apply_sparse(*args, **kwargs) # pylint: disable=protected-access
|
||||
|
||||
def get_updates(self, loss, params):
|
||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||
"""
|
||||
# self.layers = params
|
||||
out = self._internal_optimizer.get_updates(loss, params)
|
||||
self.limit_learning_rate(self.loss.beta(self.class_weights),
|
||||
self.loss.gamma()
|
||||
)
|
||||
self.project_weights_to_r()
|
||||
return out
|
||||
|
||||
def apply_gradients(self, *args, **kwargs):
|
||||
def apply_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ
|
||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||
"""
|
||||
# grads_and_vars = kwargs.get('grads_and_vars', None)
|
||||
# grads_and_vars = optimizer_v2._filter_grads(grads_and_vars)
|
||||
# var_list = [v for (_, v) in grads_and_vars]
|
||||
# self.layers = var_list
|
||||
out = self._internal_optimizer.apply_gradients(*args, **kwargs)
|
||||
self.limit_learning_rate(self.loss.beta(self.class_weights),
|
||||
self.loss.gamma()
|
||||
)
|
||||
self.project_weights_to_r()
|
||||
return out
|
||||
|
||||
def minimize(self, *args, **kwargs):
|
||||
def minimize(self, *args, **kwargs): # pylint: disable=arguments-differ
|
||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||
"""
|
||||
# self.layers = kwargs.get('var_list', None)
|
||||
out = self._internal_optimizer.minimize(*args, **kwargs)
|
||||
self.limit_learning_rate(self.loss.beta(self.class_weights),
|
||||
self.loss.gamma()
|
||||
)
|
||||
self.project_weights_to_r()
|
||||
return out
|
||||
|
||||
def _compute_gradients(self, *args, **kwargs):
|
||||
def _compute_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ,protected-access
|
||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||
"""
|
||||
# self.layers = kwargs.get('var_list', None)
|
||||
return self._internal_optimizer._compute_gradients(*args, **kwargs)
|
||||
return self._internal_optimizer._compute_gradients(*args, **kwargs) # pylint: disable=protected-access
|
||||
|
||||
def get_gradients(self, *args, **kwargs):
|
||||
def get_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ
|
||||
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
|
||||
"""
|
||||
# self.layers = kwargs.get('params', None)
|
||||
return self._internal_optimizer.get_gradients(*args, **kwargs)
|
||||
|
||||
def __enter__(self):
|
||||
noise_distribution = self.noise_distribution
|
||||
epsilon = self.epsilon
|
||||
class_weights = self.class_weights
|
||||
n_samples = self.n_samples
|
||||
if noise_distribution not in _accepted_distributions:
|
||||
raise ValueError('Detected noise distribution: {0} not one of: {1} valid'
|
||||
'distributions'.format(noise_distribution,
|
||||
_accepted_distributions))
|
||||
self.noise_distribution = noise_distribution
|
||||
self.epsilon = epsilon
|
||||
self.class_weights = class_weights
|
||||
self.n_samples = n_samples
|
||||
"""Context manager call at the beginning of with statement.
|
||||
|
||||
Returns:
|
||||
self, to be used in context manager
|
||||
"""
|
||||
self._is_init = True
|
||||
return self
|
||||
|
||||
def __call__(self,
|
||||
noise_distribution,
|
||||
epsilon,
|
||||
layers,
|
||||
noise_distribution: str,
|
||||
epsilon: float,
|
||||
layers: list,
|
||||
class_weights,
|
||||
n_samples,
|
||||
n_classes,
|
||||
n_outputs,
|
||||
batch_size
|
||||
):
|
||||
"""
|
||||
"""Entry point from context. Accepts required values for bolton method and
|
||||
stores them on the optimizer for use throughout fitting.
|
||||
|
||||
Args:
|
||||
noise_distribution: the noise distribution to pick.
|
||||
see _accepted_distributions and get_noise for
|
||||
possible values.
|
||||
epsilon: privacy parameter. Lower gives more privacy but less utility.
|
||||
class_weights: class_weights used
|
||||
layers: list of Keras/Tensorflow layers. Can be found as model.layers
|
||||
class_weights: class_weights used, which may either be a scalar or 1D
|
||||
tensor with dim == n_classes.
|
||||
n_samples: number of rows/individual samples in the training set
|
||||
n_classes: number of output classes
|
||||
layers: list of Keras/Tensorflow layers.
|
||||
n_outputs: number of output classes
|
||||
batch_size: batch size used.
|
||||
"""
|
||||
if epsilon <= 0:
|
||||
raise ValueError('Detected epsilon: {0}. '
|
||||
'Valid range is 0 < epsilon <inf'.format(epsilon))
|
||||
if noise_distribution not in _accepted_distributions:
|
||||
raise ValueError('Detected noise distribution: {0} not one of: {1} valid '
|
||||
'distributions'.format(noise_distribution,
|
||||
_accepted_distributions))
|
||||
self.noise_distribution = noise_distribution
|
||||
self.epsilon = epsilon
|
||||
self.class_weights = class_weights
|
||||
self.n_samples = n_samples
|
||||
self.n_classes = n_classes
|
||||
self.learning_rate.initialize(self.loss.beta(class_weights),
|
||||
self.loss.gamma()
|
||||
)
|
||||
self.epsilon = _ops.convert_to_tensor_v2(epsilon, dtype=self.dtype)
|
||||
self.class_weights = _ops.convert_to_tensor_v2(class_weights,
|
||||
dtype=self.dtype
|
||||
)
|
||||
self.n_samples = _ops.convert_to_tensor_v2(n_samples,
|
||||
dtype=self.dtype
|
||||
)
|
||||
self.n_outputs = _ops.convert_to_tensor_v2(n_outputs,
|
||||
dtype=self.dtype
|
||||
)
|
||||
self.layers = layers
|
||||
self.batch_size = _ops.convert_to_tensor_v2(batch_size,
|
||||
dtype=self.dtype
|
||||
)
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
|
@@ -328,30 +382,21 @@ class Bolton(optimizer_v2.OptimizerV2):
|
|||
|
||||
|
||||
"""
|
||||
# for param in self.layers:
|
||||
# if param.name.find('kernel') != -1 or param.name.find('weight') != -1:
|
||||
# input_dim = param.numpy().shape[0]
|
||||
# print(param)
|
||||
# noise = -1 * self.get_noise(self.n_samples,
|
||||
# input_dim,
|
||||
# self.n_classes,
|
||||
# self.class_weights
|
||||
# )
|
||||
# print(tf.math.subtract(param, noise))
|
||||
# param.assign(tf.math.subtract(param, noise))
|
||||
self.project_weights_to_r(True)
|
||||
for layer in self.layers:
|
||||
input_dim, output_dim = layer.kernel.shape
|
||||
noise = self.get_noise(self.n_samples,
|
||||
input_dim,
|
||||
input_dim = layer.kernel.shape[0]
|
||||
output_dim = layer.units
|
||||
noise = self.get_noise(input_dim,
|
||||
output_dim,
|
||||
self.class_weights
|
||||
)
|
||||
layer.kernel = tf.math.add(layer.kernel, noise)
|
||||
self.noise_distribution = None
|
||||
self.learning_rate.de_initialize()
|
||||
self.epsilon = -1
|
||||
self.batch_size = -1
|
||||
self.class_weights = None
|
||||
self.n_samples = None
|
||||
self.input_dim = None
|
||||
self.n_classes = None
|
||||
self.n_outputs = None
|
||||
self.layers = None
|
||||
self._is_init = False
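# A NumPy transliteration of the 'laplace' branch of get_noise used by
# __exit__ above (illustrative only; parameter names mirror the diff, and the
# sensitivity normalization by n_samples * batch_size follows this commit):
import numpy as np

def laplace_weight_noise(input_dim, output_dim, epsilon, lipschitz, gamma,
                         n_samples, batch_size, seed=1):
    """Sketch of the noise added to a (input_dim, output_dim) kernel."""
    rng = np.random.default_rng(seed)
    per_class_epsilon = epsilon / output_dim
    l2_sensitivity = 2.0 * lipschitz / (gamma * n_samples * batch_size)
    # random direction: one unit vector per output column
    direction = rng.normal(size=(input_dim, output_dim))
    direction /= np.sqrt((direction ** 2).sum(axis=0))
    # magnitude: Gamma(shape=input_dim, scale=l2_sensitivity / per_class_epsilon)
    magnitude = rng.gamma(shape=input_dim,
                          scale=l2_sensitivity / per_class_epsilon,
                          size=output_dim)
    return direction * magnitude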
|
||||
|
|
|
@@ -22,16 +22,18 @@ from tensorflow.python.platform import test
|
|||
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras.regularizers import L1L2
|
||||
from tensorflow.python.keras.initializers import constant
|
||||
from tensorflow.python.keras import losses
|
||||
from tensorflow.python.keras.models import Model
|
||||
from tensorflow.python.framework import ops as _ops
|
||||
from tensorflow.python.framework import test_util
|
||||
|
||||
from tensorflow.python import ops as _ops
|
||||
from absl.testing import parameterized
|
||||
from privacy.bolton.loss import StrongConvexMixin
|
||||
from privacy.bolton import optimizer as opt
|
||||
|
||||
|
||||
|
||||
class TestModel(Model):
|
||||
"""
|
||||
Bolton episilon-delta model
|
||||
|
@ -46,10 +48,10 @@ class TestModel(Model):
|
|||
Descent-based Analytics by Xi Wu et. al.
|
||||
"""
|
||||
|
||||
def __init__(self, n_classes=2):
|
||||
def __init__(self, n_outputs=2, input_shape=(16,), init_value=2):
|
||||
"""
|
||||
Args:
|
||||
n_classes: number of output classes to predict.
|
||||
n_outputs: number of output neurons
|
||||
epsilon: level of privacy guarantee
|
||||
noise_distribution: distribution to pull weight perturbations from
|
||||
weights_initializer: initializer for weights
|
||||
|
@ -57,13 +59,13 @@ class TestModel(Model):
|
|||
dtype: data type to use for tensors
|
||||
"""
|
||||
super(TestModel, self).__init__(name='bolton', dynamic=False)
|
||||
self.n_classes = n_classes
|
||||
self.layer_input_shape = (16, 1)
|
||||
self.n_outputs = n_outputs
|
||||
self.layer_input_shape = input_shape
|
||||
self.output_layer = tf.keras.layers.Dense(
|
||||
self.n_classes,
|
||||
input_shape=self.layer_input_shape,
|
||||
kernel_regularizer=L1L2(l2=1),
|
||||
kernel_initializer='glorot_uniform',
|
||||
self.n_outputs,
|
||||
input_shape=self.layer_input_shape,
|
||||
kernel_regularizer=L1L2(l2=1),
|
||||
kernel_initializer=constant(init_value),
|
||||
)
|
||||
|
||||
|
||||
|
@@ -84,7 +86,7 @@ class TestLoss(losses.Loss, StrongConvexMixin):
  def __init__(self, reg_lambda, C, radius_constant, name='test'):
    super(TestLoss, self).__init__(name=name)
    self.reg_lambda = reg_lambda
    self.C = C
    self.C = C  # pylint: disable=invalid-name
    self.radius_constant = radius_constant

  def radius(self):

@@ -93,7 +95,7 @@ class TestLoss(losses.Loss, StrongConvexMixin):
    Returns: radius

    """
    return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
    return _ops.convert_to_tensor_v2(self.radius_constant, dtype=tf.float32)

  def gamma(self):
    """ Gamma strongly convex

@@ -125,13 +127,17 @@ class TestLoss(losses.Loss, StrongConvexMixin):
    """
    return _ops.convert_to_tensor_v2(1, dtype=tf.float32)

  def call(self, val0, val1):
  def call(self, y_true, y_pred):
    """Loss function that is minimized at the mean of the input points."""
    return 0.5 * tf.reduce_sum(tf.math.squared_difference(val0, val1), axis=1)
    return 0.5 * tf.reduce_sum(
        tf.math.squared_difference(y_true, y_pred),
        axis=1
    )

  def max_class_weight(self, class_weight):
  def max_class_weight(self, class_weight, dtype=tf.float32):
    if class_weight is None:
      return 1
    raise NotImplementedError('')

  def kernel_regularizer(self):
    return L1L2(l2=self.reg_lambda)

@@ -182,18 +188,6 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
  """Bolton Optimizer tests"""
  @test_util.run_all_in_graph_and_eager_modes
  @parameterized.named_parameters([
      {'testcase_name': 'branch beta',
       'fn': 'limit_learning_rate',
       'args': [tf.Variable(2, dtype=tf.float32),
                tf.Variable(1, dtype=tf.float32)],
       'result': tf.Variable(0.5, dtype=tf.float32),
       'test_attr': 'learning_rate'},
      {'testcase_name': 'branch gamma',
       'fn': 'limit_learning_rate',
       'args': [tf.Variable(1, dtype=tf.float32),
                tf.Variable(1, dtype=tf.float32)],
       'result': tf.Variable(1, dtype=tf.float32),
       'test_attr': 'learning_rate'},
      {'testcase_name': 'getattr',
       'fn': '__getattr__',
       'args': ['dtype'],

@@ -202,8 +196,8 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
      {'testcase_name': 'project_weights_to_r',
       'fn': 'project_weights_to_r',
       'args': ['dtype'],
       'result': tf.float32,
       'test_attr': None},
       'result': None,
       'test_attr': ''},
  ])
  def test_fn(self, fn, args, result, test_attr):
    """test that a fn of Bolton optimizer is working as expected.

@@ -218,15 +212,176 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
    """
    tf.random.set_seed(1)
    loss = TestLoss(1, 1, 1)
    private = opt.Bolton(TestOptimizer(), loss)
    res = getattr(private, fn, None)(*args)
    bolton = opt.Bolton(TestOptimizer(), loss)
    model = TestModel(1)
    model.layers[0].kernel = \
        model.layers[0].kernel_initializer((model.layer_input_shape[0],
                                            model.n_outputs))
    bolton._is_init = True
    bolton.layers = model.layers
    bolton.epsilon = 2
    bolton.noise_distribution = 'laplace'
    bolton.n_outputs = 1
    bolton.n_samples = 1
    res = getattr(bolton, fn, None)(*args)
    if test_attr is not None:
      res = getattr(private, test_attr, None)
      res = getattr(bolton, test_attr, None)
    if hasattr(res, 'numpy') and hasattr(result, 'numpy'):  # both tensors/not
      res = res.numpy()
      result = result.numpy()
    self.assertEqual(res, result)

  @test_util.run_all_in_graph_and_eager_modes
  @parameterized.named_parameters([
      {'testcase_name': '1 value project to r=1',
       'r': 1,
       'init_value': 2,
       'shape': (1,),
       'n_out': 1,
       'result': [[1]]},
      {'testcase_name': '2 value project to r=1',
       'r': 1,
       'init_value': 2,
       'shape': (2,),
       'n_out': 1,
       'result': [[0.707107], [0.707107]]},
      {'testcase_name': '1 value project to r=2',
       'r': 2,
       'init_value': 3,
       'shape': (1,),
       'n_out': 1,
       'result': [[2]]},
      {'testcase_name': 'no project',
       'r': 2,
       'init_value': 1,
       'shape': (1,),
       'n_out': 1,
       'result': [[1]]},
  ])
  def test_project(self, r, shape, n_out, init_value, result):
    """Test that project_weights_to_r projects the kernel onto the ball of radius r.

    Args:
      r: radius of the feasible ball (the loss's radius_constant)
      shape: input shape for the test model's layer
      n_out: number of output neurons
      init_value: constant used to initialize the kernel weights
      result: the expected projected kernel
    """
    tf.random.set_seed(1)
    @tf.function
    def project_fn(r):
      loss = TestLoss(1, 1, r)
      bolton = opt.Bolton(TestOptimizer(), loss)
      model = TestModel(n_out, shape, init_value)
      model.compile(bolton, loss)
      model.layers[0].kernel = \
          model.layers[0].kernel_initializer((model.layer_input_shape[0],
                                              model.n_outputs))
      bolton._is_init = True
      bolton.layers = model.layers
      bolton.epsilon = 2
      bolton.noise_distribution = 'laplace'
      bolton.n_outputs = 1
      bolton.n_samples = 1
      bolton.project_weights_to_r()
      return _ops.convert_to_tensor_v2(bolton.layers[0].kernel, tf.float32)
    res = project_fn(r)
    self.assertAllClose(res, result)

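The expected kernels above follow from the Euclidean-ball projection w -> w * min(1, r / ||w||_2). In the '2 value project to r=1' case the kernel starts as [[2], [2]] with norm about 2.828, so each entry shrinks to about 0.707107. A quick numerical check of that formula (a sketch of the math the test encodes, not the optimizer's internal code):

    import numpy as np

    w = np.array([[2.0], [2.0]])          # kernel initialized to 2, shape (2, 1)
    r = 1.0
    norm = np.linalg.norm(w)              # ~2.8284
    projected = w * min(1.0, r / norm)    # unchanged when the norm is already <= r
    print(projected)                      # [[0.70710678], [0.70710678]]
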
  @test_util.run_all_in_graph_and_eager_modes
  @parameterized.named_parameters([
      {'testcase_name': 'normal values',
       'epsilon': 2,
       'noise': 'laplace',
       'class_weights': 1},
  ])
  def test_context_manager(self, noise, epsilon, class_weights):
    """Tests the context manager functionality of the optimizer.

    Args:
      noise: noise distribution to pick
      epsilon: epsilon privacy parameter to use
      class_weights: class_weights to use
    """
    @tf.function
    def test_run():
      loss = TestLoss(1, 1, 1)
      bolton = opt.Bolton(TestOptimizer(), loss)
      model = TestModel(1, (1,), 1)
      model.compile(bolton, loss)
      model.layers[0].kernel = \
          model.layers[0].kernel_initializer((model.layer_input_shape[0],
                                              model.n_outputs))
      with bolton(noise, epsilon, model.layers, class_weights, 1, 1, 1) as _:
        pass
      return _ops.convert_to_tensor_v2(bolton.epsilon, dtype=tf.float32)
    epsilon = test_run()
    self.assertEqual(epsilon.numpy(), -1)

  @parameterized.named_parameters([
      {'testcase_name': 'invalid noise',
       'epsilon': 1,
       'noise': 'not_valid',
       'err_msg': 'Detected noise distribution: not_valid not one of:'},
      {'testcase_name': 'invalid epsilon',
       'epsilon': -1,
       'noise': 'laplace',
       'err_msg': 'Detected epsilon: -1. Valid range is 0 < epsilon <inf'},
  ])
  def test_context_domains(self, noise, epsilon, err_msg):
    """Tests that invalid context-manager arguments raise the expected errors.

    Args:
      noise: noise distribution to pick
      epsilon: epsilon privacy parameter to use
      err_msg: the expected error message
    """

    @tf.function
    def test_run(noise, epsilon):
      loss = TestLoss(1, 1, 1)
      bolton = opt.Bolton(TestOptimizer(), loss)
      model = TestModel(1, (1,), 1)
      model.compile(bolton, loss)
      model.layers[0].kernel = \
          model.layers[0].kernel_initializer((model.layer_input_shape[0],
                                              model.n_outputs))
      with bolton(noise, epsilon, model.layers, 1, 1, 1, 1) as _:
        pass
    with self.assertRaisesRegexp(ValueError, err_msg):  # pylint: disable=deprecated-method
      test_run(noise, epsilon)

  @parameterized.named_parameters([
      {'testcase_name': 'fn: get_noise',
       'fn': 'get_noise',
       'args': [1, 1],
       'err_msg': 'ust be called from within the optimizer\'s context'},
  ])
  def test_not_in_context(self, fn, args, err_msg):
    """Tests that the expected functions raise errors when not in context.

    Args:
      fn: the function to test
      args: the arguments for said function
      err_msg: expected error message
    """
    @tf.function
    def test_run(fn, args):
      loss = TestLoss(1, 1, 1)
      bolton = opt.Bolton(TestOptimizer(), loss)
      model = TestModel(1, (1,), 1)
      model.compile(bolton, loss)
      model.layers[0].kernel = \
          model.layers[0].kernel_initializer((model.layer_input_shape[0],
                                              model.n_outputs))
      getattr(bolton, fn)(*args)

    with self.assertRaisesRegexp(Exception, err_msg):  # pylint: disable=deprecated-method
      test_run(fn, args)

  @parameterized.named_parameters([
      {'testcase_name': 'fn: get_updates',
       'fn': 'get_updates',

@@ -267,27 +422,33 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
    """
    loss = TestLoss(1, 1, 1)
    optimizer = TestOptimizer()
    optimizer = opt.Bolton(optimizer, loss)
    model = TestModel(2)
    bolton = opt.Bolton(optimizer, loss)
    model = TestModel(3)
    model.compile(optimizer, loss)
    model.layers[0].kernel_initializer(model.layer_input_shape)
    print(model.layers[0].__dict__)
    with optimizer('laplace', 2, model.layers, 1, 1, model.n_classes):
      self.assertEqual(
          getattr(optimizer, fn, lambda: 'fn not found')(*args),
          'test'
      )
    model.layers[0].kernel = \
        model.layers[0].kernel_initializer((model.layer_input_shape[0],
                                            model.n_outputs))
    model.layers[0].kernel = \
        model.layers[0].kernel_initializer((model.layer_input_shape[0],
                                            model.n_outputs))
    bolton._is_init = True
    bolton.layers = model.layers
    bolton.epsilon = 2
    bolton.noise_distribution = 'laplace'
    bolton.n_outputs = 1
    bolton.n_samples = 1
    self.assertEqual(
        getattr(bolton, fn, lambda: 'fn not found')(*args),
        'test'
    )

  @parameterized.named_parameters([
      {'testcase_name': 'fn: limit_learning_rate',
       'fn': 'limit_learning_rate',
       'args': [1, 1, 1]},
      {'testcase_name': 'fn: project_weights_to_r',
       'fn': 'project_weights_to_r',
       'args': []},
      {'testcase_name': 'fn: get_noise',
       'fn': 'get_noise',
       'args': [1, 1, 1, 1]},
       'args': [1, 1]},
  ])
  def test_not_reroute_fn(self, fn, args):
    """Test that a fn that should not be rerouted to the internal optimizer is

@@ -297,11 +458,30 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
      fn: fn to test
      args: arguments to that fn
    """
    optimizer = TestOptimizer()
    loss = TestLoss(1, 1, 1)
    optimizer = opt.Bolton(optimizer, loss)
    self.assertNotEqual(getattr(optimizer, fn, lambda: 'test')(*args),
                        'test')
    @tf.function
    def test_run(fn, args):
      loss = TestLoss(1, 1, 1)
      bolton = opt.Bolton(TestOptimizer(), loss)
      model = TestModel(1, (1,), 1)
      model.compile(bolton, loss)
      model.layers[0].kernel = \
          model.layers[0].kernel_initializer((model.layer_input_shape[0],
                                              model.n_outputs))
      bolton._is_init = True
      bolton.noise_distribution = 'laplace'
      bolton.epsilon = 1
      bolton.layers = model.layers
      bolton.class_weights = 1
      bolton.n_samples = 1
      bolton.batch_size = 1
      bolton.n_outputs = 1
      res = getattr(bolton, fn, lambda: 'test')(*args)
      if res != 'test':
        res = 1
      else:
        res = 0
      return _ops.convert_to_tensor_v2(res, dtype=tf.float32)
    self.assertNotEqual(test_run(fn, args), 0)

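These reroute tests exercise a wrapper pattern: names the wrapper defines itself (limit_learning_rate, project_weights_to_r, get_noise) are handled directly, while any other attribute lookup falls through to the wrapped optimizer. A stripped-down sketch of that pattern, assuming a plain __getattr__ fallthrough rather than Bolton's actual implementation:

    import tensorflow as tf

    class ForwardingWrapper:
      """Toy wrapper: unknown attribute lookups fall through to the inner object."""

      def __init__(self, inner):
        self._internal_optimizer = inner

      def project_weights_to_r(self):
        return 'handled by the wrapper itself'  # never rerouted

      def __getattr__(self, name):
        # only invoked when normal attribute lookup fails on the wrapper
        return getattr(self._internal_optimizer, name)

    wrapped = ForwardingWrapper(tf.optimizers.SGD())
    wrapped.project_weights_to_r()   # served by the wrapper
    wrapped.get_config()             # forwarded to the underlying SGD optimizer
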
  @parameterized.named_parameters([
      {'testcase_name': 'attr: _iterations',

@@ -323,8 +503,8 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
    )

  @parameterized.named_parameters([
      {'testcase_name': 'attr does not exist',
       'attr': '_not_valid'}
      {'testcase_name': 'attr does not exist',
       'attr': '_not_valid'}
  ])
  def test_attribute_error(self, attr):
    """ test that attribute of internal optimizer is correctly rerouted to

@@ -340,5 +520,54 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
    with self.assertRaises(AttributeError):
      getattr(optimizer, attr)

class SchedulerTest(keras_parameterized.TestCase):
  """GammaBeta Scheduler tests"""

  @parameterized.named_parameters([
      {'testcase_name': 'not in context',
       'err_msg': 'Please initialize the GammaBetaDecreasingStep Learning Rate'
                  ' Scheduler'
      }
  ])
  def test_bad_call(self, err_msg):
    """Tests that calling the scheduler before it is initialized raises an error.

    Args:
      err_msg: the expected error message
    """
    scheduler = opt.GammaBetaDecreasingStep()
    with self.assertRaisesRegexp(Exception, err_msg):  # pylint: disable=deprecated-method
      scheduler(1)

  @parameterized.named_parameters([
      {'testcase_name': 'step 1',
       'step': 1,
       'res': 0.5},
      {'testcase_name': 'step 2',
       'step': 2,
       'res': 0.5},
      {'testcase_name': 'step 3',
       'step': 3,
       'res': 0.333333333},
  ])
  def test_call(self, step, res):
    """Tests that the scheduler returns the expected learning rate for a step.

    Args:
      step: the training step to query
      res: the expected learning rate
    """
    beta = _ops.convert_to_tensor_v2(2, dtype=tf.float32)
    gamma = _ops.convert_to_tensor_v2(1, dtype=tf.float32)
    scheduler = opt.GammaBetaDecreasingStep()
    scheduler.initialize(beta, gamma)
    step = _ops.convert_to_tensor_v2(step, dtype=tf.float32)
    lr = scheduler(step)
    self.assertAllClose(lr.numpy(), res)

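The expected values are consistent with a decreasing-step rule of the form lr = min(1/beta, 1/(gamma * t)): with beta = 2 and gamma = 1 this gives 0.5 at steps 1 and 2 and 1/3 at step 3. A quick check of that formula (inferred from the test cases above, not quoted from the scheduler's source):

    def expected_lr(beta, gamma, t):
      return min(1.0 / beta, 1.0 / (gamma * t))

    print([expected_lr(2.0, 1.0, t) for t in (1, 2, 3)])  # [0.5, 0.5, 0.3333...]
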
if __name__ == '__main__':
  test.main()