Bolton created as optimizer with context manager usage.

Unit tests included.
Additional loss functions TBD.
This commit is contained in:
Christopher Choquette Choo 2019-06-17 13:25:30 -04:00
parent ec18db5ec5
commit 935d6e8480
6 changed files with 685 additions and 661 deletions

View file

@ -102,7 +102,7 @@ class StrongConvexMixin:
return tf.math.reduce_max(class_weight)
class StrongConvexHuber(losses.Huber, StrongConvexMixin):
class StrongConvexHuber(losses.Loss, StrongConvexMixin):
"""Strong Convex version of Huber loss using l2 weight regularization.
"""
@ -112,7 +112,6 @@ class StrongConvexHuber(losses.Huber, StrongConvexMixin):
radius_constant: float,
delta: float,
reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
name: str = 'huber',
dtype=tf.float32):
"""Constructor.
@ -137,13 +136,17 @@ class StrongConvexHuber(losses.Huber, StrongConvexMixin):
raise ValueError('radius_constant: {0}, should be >= 0'.format(
radius_constant
))
self.C = C
if delta <= 0:
raise ValueError('delta: {0}, should be >= 0'.format(
delta
))
self.C = C # pylint: disable=invalid-name
self.delta = delta
self.radius_constant = radius_constant
self.dtype = dtype
self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
super(StrongConvexHuber, self).__init__(
delta=delta,
name=name,
name='huber',
reduction=reduction,
)
@ -151,26 +154,25 @@ class StrongConvexHuber(losses.Huber, StrongConvexMixin):
"""Compute loss
Args:
y_true: Ground truth values.
y_true: Ground truth values. One
y_pred: The predicted values.
Returns:
Loss values per sample.
"""
# return super(StrongConvexHuber, self).call(y_true, y_pred) * self._sample_weight
h = self._fn_kwargs['delta']
h = self.delta
z = y_pred * y_true
one = tf.constant(1, dtype=self.dtype)
four = tf.constant(4, dtype=self.dtype)
if z > one + h:
return z - z
return _ops.convert_to_tensor_v2(0, dtype=self.dtype)
elif tf.math.abs(one - z) <= h:
return one / (four * h) * tf.math.pow(one + h - z, 2)
elif z < one - h:
return one - z
else:
raise ValueError('')
raise ValueError('') # shouldn't be possible to get here.
def radius(self):
"""See super class.
@ -186,7 +188,7 @@ class StrongConvexHuber(losses.Huber, StrongConvexMixin):
"""See super class.
"""
max_class_weight = self.max_class_weight(class_weight, self.dtype)
delta = _ops.convert_to_tensor_v2(self._fn_kwargs['delta'],
delta = _ops.convert_to_tensor_v2(self.delta,
dtype=self.dtype
)
return self.C * max_class_weight / (delta *
@ -250,7 +252,7 @@ class StrongConvexBinaryCrossentropy(
radius_constant
))
self.dtype = dtype
self.C = C
self.C = C # pylint: disable=invalid-name
self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
super(StrongConvexBinaryCrossentropy, self).__init__(
reduction=reduction,
@ -306,7 +308,7 @@ class StrongConvexBinaryCrossentropy(
this loss function to be strongly convex.
:return:
"""
return L1L2(l2=self.reg_lambda)
return L1L2(l2=self.reg_lambda/2)
# class StrongConvexSparseCategoricalCrossentropy(

View file

@ -79,7 +79,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
'reg_lambda': 1,
'C': 1,
'radius_constant': 1
},
}, # pylint: disable=invalid-name
])
def test_init_params(self, reg_lambda, C, radius_constant):
"""Test initialization for given arguments
@ -107,7 +107,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
'reg_lambda': -1,
'C': 1,
'radius_constant': 1
},
}, # pylint: disable=invalid-name
])
def test_bad_init_params(self, reg_lambda, C, radius_constant):
"""Test invalid domain for given params. Should return ValueError
@ -180,7 +180,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
'fn': 'kernel_regularizer',
'init_args': [1, 1, 1],
'args': [],
'result': L1L2(l2=1),
'result': L1L2(l2=0.5),
},
])
def test_fns(self, init_args, fn, args, result):

View file

@ -23,8 +23,6 @@ from tensorflow.python.framework import ops as _ops
from privacy.bolton.loss import StrongConvexMixin
from privacy.bolton.optimizer import Bolton
_accepted_distributions = ['laplace']
class BoltonModel(Model):
"""
@ -41,41 +39,28 @@ class BoltonModel(Model):
"""
def __init__(self,
n_classes,
# noise_distribution='laplace',
n_outputs,
seed=1,
dtype=tf.float32
):
""" private constructor.
Args:
n_classes: number of output classes to predict.
epsilon: level of privacy guarantee
noise_distribution: distribution to pull weight perturbations from
weights_initializer: initializer for weights
n_outputs: number of output classes to predict.
seed: random seed to use
dtype: data type to use for tensors
"""
# if noise_distribution not in _accepted_distributions:
# raise ValueError('Detected noise distribution: {0} not one of: {1} valid'
# 'distributions'.format(noise_distribution,
# _accepted_distributions))
# if epsilon <= 0:
# raise ValueError('Detected epsilon: {0}. '
# 'Valid range is 0 < epsilon <inf'.format(epsilon))
# self.epsilon = epsilon
super(BoltonModel, self).__init__(name='bolton', dynamic=False)
self.n_classes = n_classes
self.force = False
# self.noise_distribution = noise_distribution
if n_outputs <= 0:
raise ValueError('n_outputs = {0} is not valid. Must be > 0.'.format(
n_outputs
))
self.n_outputs = n_outputs
self.seed = seed
self.__in_fit = False
self._layers_instantiated = False
# self._callback = MyCustomCallback()
self._dtype = dtype
def call(self, inputs):
def call(self, inputs, training=False): # pylint: disable=arguments-differ
"""Forward pass of network
Args:
@ -87,37 +72,30 @@ class BoltonModel(Model):
return self.output_layer(inputs)
def compile(self,
optimizer='SGD',
loss=None,
optimizer,
loss,
metrics=None,
loss_weights=None,
sample_weight_mode=None,
weighted_metrics=None,
target_tensors=None,
distribute=None,
**kwargs):
kernel_initializer=tf.initializers.GlorotUniform,
**kwargs): # pylint: disable=arguments-differ
"""See super class. Default optimizer used in Bolton method is SGD.
"""
for key, val in StrongConvexMixin.__dict__.items():
if callable(val) and getattr(loss, key, None) is None:
raise ValueError("Please ensure you are passing a valid StrongConvex "
"loss that has all the required methods "
"implemented. "
"Required method: {0} not found".format(key))
if not isinstance(loss, StrongConvexMixin):
raise ValueError("loss function must be a Strongly Convex and therefore "
"extend the StrongConvexMixin.")
if not self._layers_instantiated: # compile may be called multiple times
kernel_intiializer = kwargs.get('kernel_initializer',
tf.initializers.GlorotUniform)
# for instance, if the input/outputs are not defined until fit.
self.output_layer = tf.keras.layers.Dense(
self.n_classes,
self.n_outputs,
kernel_regularizer=loss.kernel_regularizer(),
kernel_initializer=kernel_intiializer(),
kernel_initializer=kernel_initializer(),
)
# if we don't do regularization here, we require the user to
# re-instantiate the model each time they want to change the penalty
# weighting
self._layers_instantiated = True
self.output_layer.kernel_regularizer.l2 = loss.reg_lambda
if not isinstance(optimizer, Bolton):
optimizer = optimizers.get(optimizer)
optimizer = Bolton(optimizer, loss)
@ -133,69 +111,16 @@ class BoltonModel(Model):
**kwargs
)
# def _post_fit(self, x, n_samples):
# """Implements 1-time weight changes needed for Bolton method.
# In this case, specifically implements the noise addition
# assuming a strongly convex function.
#
# Args:
# x: inputs
# n_samples: number of samples in the inputs. In case the number
# cannot be readily determined by inspecting x.
#
# Returns:
#
# """
# data_size = None
# if n_samples is not None:
# data_size = n_samples
# elif hasattr(x, 'shape'):
# data_size = x.shape[0]
# elif hasattr(x, "__len__"):
# data_size = len(x)
# elif data_size is None:
# if n_samples is None:
# raise ValueError("Unable to detect the number of training "
# "samples and n_smaples was None. "
# "either pass a dataset with a .shape or "
# "__len__ attribute or explicitly pass the "
# "number of samples as n_smaples.")
# for layer in self.layers:
# # layer.kernel = layer.kernel + self._get_noise(
# # data_size
# # )
# input_dim = layer.kernel.numpy().shape[0]
# layer.kernel = layer.kernel + self.optimizer.get_noise(
# self.loss,
# data_size,
# input_dim,
# self.n_classes,
# self.class_weight
# )
def fit(self,
x=None,
y=None,
batch_size=None,
epochs=1,
verbose=1,
callbacks=None,
validation_split=0.0,
validation_data=None,
shuffle=True,
class_weight=None,
sample_weight=None,
initial_epoch=0,
steps_per_epoch=None,
validation_steps=None,
validation_freq=1,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
n_samples=None,
epsilon=2,
noise_distribution='laplace',
**kwargs):
steps_per_epoch=None,
**kwargs): # pylint: disable=arguments-differ
"""Reroutes to super fit with additional Bolton delta-epsilon privacy
requirements implemented. Note, inputs must be normalized s.t. ||x|| < 1
Requirements are as follows:
@ -207,92 +132,101 @@ class BoltonModel(Model):
Args:
n_samples: the number of individual samples in x.
epsilon: privacy parameter, which trades off between utility an privacy.
See the bolton paper for more description.
noise_distribution: the distribution to pull noise from.
class_weight: the class weights to be used. Can be a scalar or 1D tensor
whose dim == n_classes.
Returns:
See the super method for descriptions on the rest of the arguments.
"""
self.__in_fit = True
# cb = [self.optimizer.callbacks]
# if callbacks is not None:
# cb.extend(callbacks)
# callbacks = cb
if class_weight is None:
class_weight = self.calculate_class_weights(class_weight)
# self.class_weight = class_weight
if n_samples is not None:
data_size = n_samples
elif hasattr(x, 'shape'):
data_size = x.shape[0]
elif hasattr(x, "__len__"):
data_size = len(x)
else:
data_size = None
batch_size_ = self._validate_or_infer_batch_size(batch_size,
steps_per_epoch,
x
)
# inferring batch_size to be passed to optimizer. batch_size must remain its
# initial value when passed to super().fit()
if batch_size_ is None:
raise ValueError('batch_size: {0} is an '
'invalid value'.format(batch_size_))
with self.optimizer(noise_distribution,
epsilon,
self.layers,
class_weight,
n_samples,
self.n_classes,
) as optim:
data_size,
self.n_outputs,
batch_size_,
) as _:
out = super(BoltonModel, self).fit(x=x,
y=y,
batch_size=batch_size,
epochs=epochs,
verbose=verbose,
callbacks=callbacks,
validation_split=validation_split,
validation_data=validation_data,
shuffle=shuffle,
class_weight=class_weight,
sample_weight=sample_weight,
initial_epoch=initial_epoch,
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps,
validation_freq=validation_freq,
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
**kwargs
)
return out
def fit_generator(self,
generator,
steps_per_epoch=None,
epochs=1,
verbose=1,
callbacks=None,
validation_data=None,
validation_steps=None,
validation_freq=1,
class_weight=None,
max_queue_size=10,
workers=1,
use_multiprocessing=False,
shuffle=True,
initial_epoch=0,
n_samples=None
):
noise_distribution='laplace',
epsilon=2,
n_samples=None,
steps_per_epoch=None,
**kwargs
): # pylint: disable=arguments-differ
"""
This method is the same as fit except for when the passed dataset
is a generator. See super method and fit for more details.
Args:
n_samples: number of individual samples in x
noise_distribution: the distribution to get noise from.
epsilon: privacy parameter, which trades off utility and privacy. See
Bolton paper for more description.
class_weight: the class weights to be used. Can be a scalar or 1D tensor
whose dim == n_classes.
See the super method for descriptions on the rest of the arguments.
"""
if class_weight is None:
class_weight = self.calculate_class_weights(class_weight)
self.class_weight = class_weight
out = super(BoltonModel, self).fit_generator(
generator,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
verbose=verbose,
callbacks=callbacks,
validation_data=validation_data,
validation_steps=validation_steps,
validation_freq=validation_freq,
class_weight=class_weight,
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
shuffle=shuffle,
initial_epoch=initial_epoch
)
if not self.__in_fit:
self._post_fit(generator, n_samples)
if n_samples is not None:
data_size = n_samples
elif hasattr(generator, 'shape'):
data_size = generator.shape[0]
elif hasattr(generator, "__len__"):
data_size = len(generator)
else:
data_size = None
batch_size = self._validate_or_infer_batch_size(None,
steps_per_epoch,
generator
)
with self.optimizer(noise_distribution,
epsilon,
self.layers,
class_weight,
data_size,
self.n_outputs,
batch_size
) as _:
out = super(BoltonModel, self).fit_generator(
generator,
class_weight=class_weight,
steps_per_epoch=steps_per_epoch,
**kwargs
)
return out
def calculate_class_weights(self,
@ -336,7 +270,7 @@ class BoltonModel(Model):
"class_weights=%s" % class_weights)
elif class_weights is not None:
if num_classes is None:
raise ValueError("You must pass a value for num_classes if"
raise ValueError("You must pass a value for num_classes if "
"creating an array of class_weights")
# performing class weight calculation
if class_weights is None:
@ -357,195 +291,9 @@ class BoltonModel(Model):
"1D array".format(class_weights.shape))
if class_weights.shape[0] != num_classes:
raise ValueError(
"Detected array length: {0} instead of: {1}".format(
class_weights.shape[0],
num_classes
)
"Detected array length: {0} instead of: {1}".format(
class_weights.shape[0],
num_classes
)
)
return class_weights
# def _project_weights_to_r(self, r, force=False):
# """helper method to normalize the weights to the R-ball.
#
# Args:
# r: radius of "R-Ball". Scalar to normalize to.
# force: True to normalize regardless of previous weight values.
# False to check if weights > R-ball and only normalize then.
#
# Returns:
#
# """
# for layer in self.layers:
# weight_norm = tf.norm(layer.kernel, axis=0)
# if force:
# layer.kernel = layer.kernel / (weight_norm / r)
# elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self._dtype)) > 0:
# layer.kernel = layer.kernel / (weight_norm / r)
# def _get_noise(self, distribution, data_size):
# """Sample noise to be added to weights for privacy guarantee
#
# Args:
# distribution: the distribution type to pull noise from
# data_size: the number of samples
#
# Returns: noise in shape of layer's weights to be added to the weights.
#
# """
# distribution = distribution.lower()
# input_dim = self.layers[0].kernel.numpy().shape[0]
# loss = self.loss
# if distribution == _accepted_distributions[0]: # laplace
# per_class_epsilon = self.epsilon / (self.n_classes)
# l2_sensitivity = (2 *
# loss.lipchitz_constant(self.class_weight)) / \
# (loss.gamma() * data_size)
# unit_vector = tf.random.normal(shape=(input_dim, self.n_classes),
# mean=0,
# seed=1,
# stddev=1.0,
# dtype=self._dtype)
# unit_vector = unit_vector / tf.math.sqrt(
# tf.reduce_sum(tf.math.square(unit_vector), axis=0)
# )
#
# beta = l2_sensitivity / per_class_epsilon
# alpha = input_dim # input_dim
# gamma = tf.random.gamma([self.n_classes],
# alpha,
# beta=1 / beta,
# seed=1,
# dtype=self._dtype
# )
# return unit_vector * gamma
# raise NotImplementedError('Noise distribution: {0} is not '
# 'a valid distribution'.format(distribution))
if __name__ == '__main__':
import tensorflow as tf
import os
import time
import matplotlib.pyplot as plt
_URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz'
path_to_zip = tf.keras.utils.get_file('facades.tar.gz',
origin=_URL,
extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')
BUFFER_SIZE = 400
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
def load(image_file):
image = tf.io.read_file(image_file)
image = tf.image.decode_jpeg(image)
w = tf.shape(image)[1]
w = w // 2
real_image = image[:, :w, :]
input_image = image[:, w:, :]
input_image = tf.cast(input_image, tf.float32)
real_image = tf.cast(real_image, tf.float32)
return input_image, real_image
inp, re = load(PATH + 'train/100.jpg')
# casting to int for matplotlib to show the image
plt.figure()
plt.imshow(inp / 255.0)
plt.figure()
plt.imshow(re / 255.0)
def resize(input_image, real_image, height, width):
input_image = tf.image.resize(input_image, [height, width],
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
real_image = tf.image.resize(real_image, [height, width],
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
return input_image, real_image
def random_crop(input_image, real_image):
stacked_image = tf.stack([input_image, real_image], axis=0)
cropped_image = tf.image.random_crop(
stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
return cropped_image[0], cropped_image[1]
def normalize(input_image, real_image):
input_image = (input_image / 127.5) - 1
real_image = (real_image / 127.5) - 1
return input_image, real_image
@tf.function()
def random_jitter(input_image, real_image):
# resizing to 286 x 286 x 3
input_image, real_image = resize(input_image, real_image, 286, 286)
# randomly cropping to 256 x 256 x 3
input_image, real_image = random_crop(input_image, real_image)
if tf.random.uniform(()) > 0.5:
# random mirroring
input_image = tf.image.flip_left_right(input_image)
real_image = tf.image.flip_left_right(real_image)
return input_image, real_image
def load_image_train(image_file):
input_image, real_image = load(image_file)
input_image, real_image = random_jitter(input_image, real_image)
input_image, real_image = normalize(input_image, real_image)
return input_image, real_image
def load_image_test(image_file):
input_image, real_image = load(image_file)
input_image, real_image = resize(input_image, real_image,
IMG_HEIGHT, IMG_WIDTH)
input_image, real_image = normalize(input_image, real_image)
return input_image, real_image
train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg')
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.map(load_image_train,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.batch(1)
# steps_per_epoch = training_utils.infer_steps_for_dataset(
# train_dataset, None, epochs=1, steps_name='steps')
# for batch in train_dataset:
# print(batch[1].shape)
test_dataset = tf.data.Dataset.list_files(PATH + 'test/*.jpg')
# shuffling so that for every epoch a different image is generated
# to predict and display the progress of our model.
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(1)
be = BoltonModel(3, 2)
from tensorflow.python.keras.optimizer_v2 import adam
from privacy.bolton import loss
test = adam.Adam()
l = loss.StrongConvexBinaryCrossentropy(1, 2, 1)
be.compile(test, l)
print("Eager exeuction: {0}".format(tf.executing_eagerly()))
be.fit(train_dataset, verbose=0, steps_per_epoch=1, n_samples=1)

View file

@ -19,25 +19,22 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow.python.platform import test
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
from tensorflow.python.keras import losses
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import test_util
from privacy.bolton import model
from privacy.bolton.loss import StrongConvexMixin
from absl.testing import parameterized
from absl.testing import absltest
from tensorflow.python.keras.regularizers import L1L2
from absl.testing import parameterized
from privacy.bolton import model
from privacy.bolton.optimizer import Bolton
from privacy.bolton.loss import StrongConvexMixin
class TestLoss(losses.Loss, StrongConvexMixin):
"""Test loss function for testing Bolton model"""
def __init__(self, reg_lambda, C, radius_constant, name='test'):
super(TestLoss, self).__init__(name=name)
self.reg_lambda = reg_lambda
self.C = C
self.C = C # pylint: disable=invalid-name
self.radius_constant = radius_constant
def radius(self):
@ -78,13 +75,17 @@ class TestLoss(losses.Loss, StrongConvexMixin):
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def call(self, val0, val1):
def call(self, y_true, y_pred):
"""Loss function that is minimized at the mean of the input points."""
return 0.5 * tf.reduce_sum(tf.math.squared_difference(val0, val1), axis=1)
return 0.5 * tf.reduce_sum(
tf.math.squared_difference(y_true, y_pred),
axis=1
)
def max_class_weight(self, class_weight):
if class_weight is None:
return 1
raise ValueError('')
def kernel_regularizer(self):
return L1L2(l2=self.reg_lambda)
@ -116,125 +117,91 @@ class InitTests(keras_parameterized.TestCase):
@parameterized.named_parameters([
{'testcase_name': 'normal',
'n_classes': 1,
'epsilon': 1,
'noise_distribution': 'laplace',
'seed': 1
'n_outputs': 1,
},
{'testcase_name': 'extreme range',
'n_classes': 5,
'epsilon': 0.1,
'noise_distribution': 'laplace',
'seed': 10
},
{'testcase_name': 'extreme range2',
'n_classes': 50,
'epsilon': 10,
'noise_distribution': 'laplace',
'seed': 100
{'testcase_name': 'many outputs',
'n_outputs': 100,
},
])
def test_init_params(
self, n_classes, epsilon, noise_distribution, seed):
def test_init_params(self, n_outputs):
"""test initialization of BoltonModel
Args:
n_outputs: number of output neurons
"""
# test valid domains for each variable
clf = model.Bolton(n_classes,
epsilon,
noise_distribution,
seed
)
self.assertIsInstance(clf, model.Bolton)
clf = model.BoltonModel(n_outputs)
self.assertIsInstance(clf, model.BoltonModel)
@parameterized.named_parameters([
{'testcase_name': 'invalid noise',
'n_classes': 1,
'epsilon': 1,
'noise_distribution': 'not_valid',
'weights_initializer': tf.initializers.GlorotUniform(),
},
{'testcase_name': 'invalid epsilon',
'n_classes': 1,
'epsilon': -1,
'noise_distribution': 'laplace',
'weights_initializer': tf.initializers.GlorotUniform(),
{'testcase_name': 'invalid n_outputs',
'n_outputs': -1,
},
])
def test_bad_init_params(
self,
n_classes,
epsilon,
noise_distribution,
weights_initializer):
def test_bad_init_params(self, n_outputs):
"""test bad initializations of BoltonModel that should raise errors
Args:
n_outputs: number of output neurons
"""
# test invalid domains for each variable, especially noise
seed = 1
with self.assertRaises(ValueError):
clf = model.Bolton(n_classes,
epsilon,
noise_distribution,
weights_initializer,
seed
)
model.BoltonModel(n_outputs)
@parameterized.named_parameters([
{'testcase_name': 'string compile',
'n_classes': 1,
'n_outputs': 1,
'loss': TestLoss(1, 1, 1),
'optimizer': 'adam',
'weights_initializer': tf.initializers.GlorotUniform(),
},
{'testcase_name': 'test compile',
'n_classes': 100,
'n_outputs': 100,
'loss': TestLoss(1, 1, 1),
'optimizer': TestOptimizer(),
'weights_initializer': tf.initializers.GlorotUniform(),
},
{'testcase_name': 'invalid weights initializer',
'n_classes': 1,
'loss': TestLoss(1, 1, 1),
'optimizer': TestOptimizer(),
'weights_initializer': 'not_valid',
},
])
def test_compile(self, n_classes, loss, optimizer, weights_initializer):
def test_compile(self, n_outputs, loss, optimizer):
"""test compilation of BoltonModel
Args:
n_outputs: number of output neurons
loss: instantiated TestLoss instance
optimizer: instanced TestOptimizer instance
"""
# test compilation of valid tf.optimizer and tf.loss
epsilon = 1
noise_distribution = 'laplace'
with self.cached_session():
clf = model.Bolton(n_classes,
epsilon,
noise_distribution,
weights_initializer
)
clf = model.BoltonModel(n_outputs)
clf.compile(optimizer, loss)
self.assertEqual(clf.loss, loss)
@parameterized.named_parameters([
{'testcase_name': 'Not strong loss',
'n_classes': 1,
'n_outputs': 1,
'loss': losses.BinaryCrossentropy(),
'optimizer': 'adam',
},
{'testcase_name': 'Not valid optimizer',
'n_classes': 1,
'n_outputs': 1,
'loss': TestLoss(1, 1, 1),
'optimizer': 'ada',
}
])
def test_bad_compile(self, n_classes, loss, optimizer):
def test_bad_compile(self, n_outputs, loss, optimizer):
"""test bad compilations of BoltonModel that should raise errors
Args:
n_outputs: number of output neurons
loss: instantiated TestLoss instance
optimizer: instanced TestOptimizer instance
"""
# test compilaton of invalid tf.optimizer and non instantiated loss.
epsilon = 1
noise_distribution = 'laplace'
weights_initializer = tf.initializers.GlorotUniform()
with self.cached_session():
with self.assertRaises((ValueError, AttributeError)):
clf = model.Bolton(n_classes,
epsilon,
noise_distribution,
weights_initializer
)
clf = model.BoltonModel(n_outputs)
clf.compile(optimizer, loss)
def _cat_dataset(n_samples, input_dim, n_classes, t='train', generator=False):
def _cat_dataset(n_samples, input_dim, n_classes, generator=False):
"""
Creates a categorically encoded dataset (y is categorical).
returns the specified dataset either as a static array or as a generator.
@ -245,10 +212,9 @@ def _cat_dataset(n_samples, input_dim, n_classes, t='train', generator=False):
n_samples: number of rows
input_dim: input dimensionality
n_classes: output dimensionality
t: one of 'train', 'val', 'test'
generator: False for array, True for generator
Returns:
X as (n_samples, input_dim), Y as (n_samples, n_classes)
X as (n_samples, input_dim), Y as (n_samples, n_outputs)
"""
x_stack = []
y_stack = []
@ -269,25 +235,39 @@ def _cat_dataset(n_samples, input_dim, n_classes, t='train', generator=False):
def _do_fit(n_samples,
input_dim,
n_classes,
n_outputs,
epsilon,
generator,
batch_size,
reset_n_samples,
optimizer,
loss,
callbacks,
distribution='laplace'):
clf = model.Bolton(n_classes,
epsilon,
distribution
)
"""Helper to instantiate necessary components for fitting and perform a model
fit.
Args:
n_samples: number of samples in dataset
input_dim: the sample dimensionality
n_outputs: number of output neurons
epsilon: privacy parameter
generator: True to create a generator, False to use an iterator
batch_size: batch_size to use
reset_n_samples: True to set _samples to None prior to fitting.
False does nothing
optimizer: instance of TestOptimizer
loss: instance of TestLoss
distribution: distribution to get noise from.
Returns: BoltonModel instsance
"""
clf = model.BoltonModel(n_outputs)
clf.compile(optimizer, loss)
if generator:
x = _cat_dataset(
n_samples,
input_dim,
n_classes,
n_outputs,
generator=generator
)
y = None
@ -295,25 +275,20 @@ def _do_fit(n_samples,
x = x.shuffle(n_samples//2)
batch_size = None
else:
x, y = _cat_dataset(n_samples, input_dim, n_classes, generator=generator)
x, y = _cat_dataset(n_samples, input_dim, n_outputs, generator=generator)
if reset_n_samples:
n_samples = None
if callbacks is not None:
callbacks = [callbacks]
clf.fit(x,
y,
batch_size=batch_size,
n_samples=n_samples,
callbacks=callbacks
noise_distribution=distribution,
epsilon=epsilon
)
return clf
class TestCallback(tf.keras.callbacks.Callback):
pass
class FitTests(keras_parameterized.TestCase):
"""Test cases for keras model fitting"""
@ -322,27 +297,29 @@ class FitTests(keras_parameterized.TestCase):
{'testcase_name': 'iterator fit',
'generator': False,
'reset_n_samples': True,
'callbacks': None
},
{'testcase_name': 'iterator fit no samples',
'generator': False,
'reset_n_samples': True,
'callbacks': None
},
{'testcase_name': 'generator fit',
'generator': True,
'reset_n_samples': False,
'callbacks': None
},
{'testcase_name': 'with callbacks',
'generator': True,
'reset_n_samples': False,
'callbacks': TestCallback()
},
])
def test_fit(self, generator, reset_n_samples, callbacks):
def test_fit(self, generator, reset_n_samples):
"""Tests fitting of BoltonModel
Args:
generator: True for generator test, False for iterator test.
reset_n_samples: True to reset the n_samples to None, False does nothing
"""
loss = TestLoss(1, 1, 1)
optimizer = TestOptimizer()
optimizer = Bolton(TestOptimizer(), loss)
n_classes = 2
input_dim = 5
epsilon = 1
@ -358,28 +335,27 @@ class FitTests(keras_parameterized.TestCase):
reset_n_samples,
optimizer,
loss,
callbacks
)
self.assertEqual(hasattr(clf, 'layers'), True)
@parameterized.named_parameters([
{'testcase_name': 'generator fit',
'generator': True,
'reset_n_samples': False,
'callbacks': None
},
])
def test_fit_gen(self, generator, reset_n_samples, callbacks):
def test_fit_gen(self, generator):
"""Tests the fit_generator method of BoltonModel
Args:
generator: True to test with a generator dataset
"""
loss = TestLoss(1, 1, 1)
optimizer = TestOptimizer()
n_classes = 2
input_dim = 5
epsilon = 1
batch_size = 1
n_samples = 10
clf = model.Bolton(n_classes,
epsilon
)
clf = model.BoltonModel(n_classes)
clf.compile(optimizer, loss)
x = _cat_dataset(
n_samples,
@ -405,6 +381,14 @@ class FitTests(keras_parameterized.TestCase):
},
])
def test_bad_fit(self, generator, reset_n_samples, distribution):
"""Tests fitting with invalid parameters, which should raise an error
Args:
generator: True to test with generator, False is iterator
reset_n_samples: True to reset the n_samples param to None prior to
passing it to fit
distribution: distribution to get noise from.
"""
with self.assertRaises(ValueError):
loss = TestLoss(1, 1, 1)
optimizer = TestOptimizer()
@ -423,7 +407,6 @@ class FitTests(keras_parameterized.TestCase):
reset_n_samples,
optimizer,
loss,
None,
distribution
)
@ -450,7 +433,15 @@ class FitTests(keras_parameterized.TestCase):
num_classes,
result
):
clf = model.Bolton(1, 1)
"""Tests the BOltonModel calculate_class_weights method
Args:
class_weights: the class_weights to use
class_counts: count of number of samples for each class
num_classes: number of outputs neurons
result: expected result
"""
clf = model.BoltonModel(1, 1)
expected = clf.calculate_class_weights(class_weights,
class_counts,
num_classes
@ -508,12 +499,21 @@ class FitTests(keras_parameterized.TestCase):
num_classes,
err_msg
):
clf = model.Bolton(1, 1)
with self.assertRaisesRegexp(ValueError, err_msg):
expected = clf.calculate_class_weights(class_weights,
class_counts,
num_classes
)
"""Tests the BOltonModel calculate_class_weights method with invalid params
which should raise the expected errors.
Args:
class_weights: the class_weights to use
class_counts: count of number of samples for each class
num_classes: number of outputs neurons
result: expected result
"""
clf = model.BoltonModel(1, 1)
with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method
clf.calculate_class_weights(class_weights,
class_counts,
num_classes
)
if __name__ == '__main__':

View file

@ -19,9 +19,74 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import math_ops
from tensorflow.python import ops as _ops
from privacy.bolton.loss import StrongConvexMixin
_accepted_distributions = ['laplace']
_accepted_distributions = ['laplace'] # implemented distributions for noising
class GammaBetaDecreasingStep(
optimizer_v2.learning_rate_schedule.LearningRateSchedule
):
"""
Learning Rate Scheduler using the minimum of 1/beta and 1/(gamma * step)
at each step. A required step for privacy guarantees.
"""
def __init__(self):
self.is_init = False
self.beta = None
self.gamma = None
def __call__(self, step):
"""
returns the learning rate
Args:
step: the current iteration number
Returns:
decayed learning rate to minimum of 1/beta and 1/(gamma * step) as per
the Bolton privacy requirements.
"""
if not self.is_init:
raise AttributeError('Please initialize the {0} Learning Rate Scheduler.'
'This is performed automatically by using the '
'{1} as a context manager, '
'as desired'.format(self.__class__.__name__,
Bolton.__class__.__name__
)
)
dtype = self.beta.dtype
one = tf.constant(1, dtype)
return tf.math.minimum(tf.math.reduce_min(one/self.beta),
one/(self.gamma*math_ops.cast(step, dtype))
)
def get_config(self):
"""
config to setup the learning rate scheduler.
"""
return {'beta': self.beta, 'gamma': self.gamma}
def initialize(self, beta, gamma):
"""setup the learning rate scheduler with the beta and gamma values provided
by the loss function. Meant to be used with .fit as the loss params may
depend on values passed to fit.
Args:
beta: Smoothness value. See StrongConvexMixin
gamma: Strong Convexity parameter. See StrongConvexMixin.
"""
self.is_init = True
self.beta = beta
self.gamma = gamma
def de_initialize(self):
"""De initialize the scheduler after fitting, in case another fit call has
different loss parameters.
"""
self.is_init = False
self.beta = None
self.gamma = None
class Bolton(optimizer_v2.OptimizerV2):
@ -31,11 +96,24 @@ class Bolton(optimizer_v2.OptimizerV2):
passed, "Bolton" enables the bolton model to control the learning rate
based on the strongly convex loss.
To use the Bolton method, you must:
1. instantiate it with an instantiated tf optimizer and StrongConvexLoss.
2. use it as a context manager around your .fit method internals.
This can be accomplished by the following:
optimizer = tf.optimizers.SGD()
loss = privacy.bolton.losses.StrongConvexBinaryCrossentropy()
bolton = Bolton(optimizer, loss)
with bolton(*args) as _:
model.fit()
The args required for the context manager can be found in the __call__
method.
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et. al.
"""
def __init__(self,
def __init__(self, # pylint: disable=super-init-not-called
optimizer: optimizer_v2.OptimizerV2,
loss: StrongConvexMixin,
dtype=tf.float32,
@ -45,10 +123,11 @@ class Bolton(optimizer_v2.OptimizerV2):
Args:
optimizer: Optimizer_v2 or subclass to be used as the optimizer
(wrapped).
loss: StrongConvexLoss function that the model is being compiled with.
"""
if not isinstance(loss, StrongConvexMixin):
raise ValueError("loss function must be a Strongly Convex and therfore"
raise ValueError("loss function must be a Strongly Convex and therefore "
"extend the StrongConvexMixin.")
self._private_attributes = ['_internal_optimizer',
'dtype',
@ -58,13 +137,19 @@ class Bolton(optimizer_v2.OptimizerV2):
'class_weights',
'input_dim',
'n_samples',
'n_classes',
'n_outputs',
'layers',
'_model'
'batch_size',
'_is_init'
]
self._internal_optimizer = optimizer
self.learning_rate = GammaBetaDecreasingStep() # use the Bolton Learning
# rate scheduler, as required for privacy guarantees. This will still need
# to get values from the loss function near the time that .fit is called
# on the model (when this optimizer will be called as a context manager)
self.dtype = dtype
self.loss = loss
self._is_init = False
def get_config(self):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
@ -75,49 +160,44 @@ class Bolton(optimizer_v2.OptimizerV2):
"""helper method to normalize the weights to the R-ball.
Args:
r: radius of "R-Ball". Scalar to normalize to.
force: True to normalize regardless of previous weight values.
False to check if weights > R-ball and only normalize then.
Returns:
"""
r = self.loss.radius()
radius = self.loss.radius()
for layer in self.layers:
if tf.executing_eagerly():
weight_norm = tf.norm(layer.kernel, axis=0)
if force:
layer.kernel = layer.kernel / (weight_norm / r)
elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self.dtype)) > 0:
layer.kernel = layer.kernel / (weight_norm / r)
weight_norm = tf.norm(layer.kernel, axis=0)
if force:
layer.kernel = layer.kernel / (weight_norm / radius)
else:
weight_norm = tf.norm(layer.kernel, axis=0)
if force:
layer.kernel = layer.kernel / (weight_norm / r)
else:
layer.kernel = tf.cond(
tf.reduce_sum(tf.cast(weight_norm > r, dtype=self.dtype)) > 0,
lambda: layer.kernel / (weight_norm / r),
lambda: layer.kernel
)
layer.kernel = tf.cond(
tf.reduce_sum(tf.cast(weight_norm > radius, dtype=self.dtype)) > 0,
lambda k=layer.kernel, w=weight_norm, r=radius: k / (w / r), # pylint: disable=cell-var-from-loop
lambda k=layer.kernel: k # pylint: disable=cell-var-from-loop
)
def get_noise(self, data_size, input_dim, output_dim, class_weight):
def get_noise(self, input_dim, output_dim):
"""Sample noise to be added to weights for privacy guarantee
Args:
distribution: the distribution type to pull noise from
data_size: the number of samples
input_dim: the input dimensionality for the weights
output_dim the output dimensionality for the weights
Returns: noise in shape of layer's weights to be added to the weights.
"""
if not self._is_init:
raise Exception('This method must be called from within the optimizer\'s '
'context.')
loss = self.loss
distribution = self.noise_distribution.lower()
if distribution == _accepted_distributions[0]: # laplace
per_class_epsilon = self.epsilon / (output_dim)
l2_sensitivity = (2 *
loss.lipchitz_constant(class_weight)) / \
(loss.gamma() * data_size)
loss.lipchitz_constant(self.class_weights)) / \
(loss.gamma() * self.n_samples * self.batch_size)
unit_vector = tf.random.normal(shape=(input_dim, output_dim),
mean=0,
seed=1,
@ -139,28 +219,7 @@ class Bolton(optimizer_v2.OptimizerV2):
raise NotImplementedError('Noise distribution: {0} is not '
'a valid distribution'.format(distribution))
def limit_learning_rate(self, beta, gamma):
"""Implements learning rate limitation that is required by the bolton
method for sensitivity bounding of the strongly convex function.
Sets the learning rate to the min(1/beta, 1/(gamma*t))
Args:
is_eager: Whether the model is running in eager mode
beta: loss function beta-smoothness
gamma: loss function gamma-strongly convex
Returns: None
"""
numerator = tf.constant(1, dtype=self.dtype)
t = tf.cast(self._iterations, self.dtype)
# will exist on the internal optimizer
if numerator / beta < numerator / (gamma * t):
self.learning_rate = numerator / beta
else:
self.learning_rate = numerator / (gamma * t)
def from_config(self, *args, **kwargs):
def from_config(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
return self._internal_optimizer.from_config(*args, **kwargs)
@ -176,21 +235,19 @@ class Bolton(optimizer_v2.OptimizerV2):
from _internal_optimizer.
"""
if name == '_private_attributes':
return getattr(self, name)
elif name in self._private_attributes:
if name == '_private_attributes' or name in self._private_attributes:
return getattr(self, name)
optim = object.__getattribute__(self, '_internal_optimizer')
try:
return object.__getattribute__(optim, name)
except AttributeError:
raise AttributeError("Neither '{0}' nor '{1}' object has attribute '{2}'"
"".format(
self.__class__.__name__,
self._internal_optimizer.__class__.__name__,
name
)
)
raise AttributeError(
"Neither '{0}' nor '{1}' object has attribute '{2}'"
"".format(self.__class__.__name__,
self._internal_optimizer.__class__.__name__,
name
)
)
def __setattr__(self, key, value):
""" Set attribute to self instance if its the internal optimizer.
@ -205,113 +262,110 @@ class Bolton(optimizer_v2.OptimizerV2):
"""
if key == '_private_attributes':
object.__setattr__(self, key, value)
elif key in key in self._private_attributes:
elif key in self._private_attributes:
object.__setattr__(self, key, value)
else:
setattr(self._internal_optimizer, key, value)
def _resource_apply_dense(self, *args, **kwargs):
def _resource_apply_dense(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
return self._internal_optimizer._resource_apply_dense(*args, **kwargs)
return self._internal_optimizer._resource_apply_dense(*args, **kwargs) # pylint: disable=protected-access
def _resource_apply_sparse(self, *args, **kwargs):
def _resource_apply_sparse(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
return self._internal_optimizer._resource_apply_sparse(*args, **kwargs)
return self._internal_optimizer._resource_apply_sparse(*args, **kwargs) # pylint: disable=protected-access
def get_updates(self, loss, params):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
# self.layers = params
out = self._internal_optimizer.get_updates(loss, params)
self.limit_learning_rate(self.loss.beta(self.class_weights),
self.loss.gamma()
)
self.project_weights_to_r()
return out
def apply_gradients(self, *args, **kwargs):
def apply_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
# grads_and_vars = kwargs.get('grads_and_vars', None)
# grads_and_vars = optimizer_v2._filter_grads(grads_and_vars)
# var_list = [v for (_, v) in grads_and_vars]
# self.layers = var_list
out = self._internal_optimizer.apply_gradients(*args, **kwargs)
self.limit_learning_rate(self.loss.beta(self.class_weights),
self.loss.gamma()
)
self.project_weights_to_r()
return out
def minimize(self, *args, **kwargs):
def minimize(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
# self.layers = kwargs.get('var_list', None)
out = self._internal_optimizer.minimize(*args, **kwargs)
self.limit_learning_rate(self.loss.beta(self.class_weights),
self.loss.gamma()
)
self.project_weights_to_r()
return out
def _compute_gradients(self, *args, **kwargs):
def _compute_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ,protected-access
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
# self.layers = kwargs.get('var_list', None)
return self._internal_optimizer._compute_gradients(*args, **kwargs)
return self._internal_optimizer._compute_gradients(*args, **kwargs) # pylint: disable=protected-access
def get_gradients(self, *args, **kwargs):
def get_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
# self.layers = kwargs.get('params', None)
return self._internal_optimizer.get_gradients(*args, **kwargs)
def __enter__(self):
noise_distribution = self.noise_distribution
epsilon = self.epsilon
class_weights = self.class_weights
n_samples = self.n_samples
if noise_distribution not in _accepted_distributions:
raise ValueError('Detected noise distribution: {0} not one of: {1} valid'
'distributions'.format(noise_distribution,
_accepted_distributions))
self.noise_distribution = noise_distribution
self.epsilon = epsilon
self.class_weights = class_weights
self.n_samples = n_samples
"""Context manager call at the beginning of with statement.
Returns:
self, to be used in context manager
"""
self._is_init = True
return self
def __call__(self,
noise_distribution,
epsilon,
layers,
noise_distribution: str,
epsilon: float,
layers: list,
class_weights,
n_samples,
n_classes,
n_outputs,
batch_size
):
"""
"""Entry point from context. Accepts required values for bolton method and
stores them on the optimizer for use throughout fitting.
Args:
noise_distribution: the noise distribution to pick.
see _accepted_distributions and get_noise for
possible values.
epsilon: privacy parameter. Lower gives more privacy but less utility.
class_weights: class_weights used
layers: list of Keras/Tensorflow layers. Can be found as model.layers
class_weights: class_weights used, which may either be a scalar or 1D
tensor with dim == n_classes.
n_samples number of rows/individual samples in the training set
n_classes: number of output classes
layers: list of Keras/Tensorflow layers.
n_outputs: number of output classes
batch_size: batch size used.
"""
if epsilon <= 0:
raise ValueError('Detected epsilon: {0}. '
'Valid range is 0 < epsilon <inf'.format(epsilon))
if noise_distribution not in _accepted_distributions:
raise ValueError('Detected noise distribution: {0} not one of: {1} valid'
'distributions'.format(noise_distribution,
_accepted_distributions))
self.noise_distribution = noise_distribution
self.epsilon = epsilon
self.class_weights = class_weights
self.n_samples = n_samples
self.n_classes = n_classes
self.learning_rate.initialize(self.loss.beta(class_weights),
self.loss.gamma()
)
self.epsilon = _ops.convert_to_tensor_v2(epsilon, dtype=self.dtype)
self.class_weights = _ops.convert_to_tensor_v2(class_weights,
dtype=self.dtype
)
self.n_samples = _ops.convert_to_tensor_v2(n_samples,
dtype=self.dtype
)
self.n_outputs = _ops.convert_to_tensor_v2(n_outputs,
dtype=self.dtype
)
self.layers = layers
self.batch_size = _ops.convert_to_tensor_v2(batch_size,
dtype=self.dtype
)
return self
def __exit__(self, *args):
@ -328,30 +382,21 @@ class Bolton(optimizer_v2.OptimizerV2):
"""
# for param in self.layers:
# if param.name.find('kernel') != -1 or param.name.find('weight') != -1:
# input_dim = param.numpy().shape[0]
# print(param)
# noise = -1 * self.get_noise(self.n_samples,
# input_dim,
# self.n_classes,
# self.class_weights
# )
# print(tf.math.subtract(param, noise))
# param.assign(tf.math.subtract(param, noise))
self.project_weights_to_r(True)
for layer in self.layers:
input_dim, output_dim = layer.kernel.shape
noise = self.get_noise(self.n_samples,
input_dim,
input_dim = layer.kernel.shape[0]
output_dim = layer.units
noise = self.get_noise(input_dim,
output_dim,
self.class_weights
)
layer.kernel = tf.math.add(layer.kernel, noise)
self.noise_distribution = None
self.learning_rate.de_initialize()
self.epsilon = -1
self.batch_size = -1
self.class_weights = None
self.n_samples = None
self.input_dim = None
self.n_classes = None
self.n_outputs = None
self.layers = None
self._is_init = False

View file

@ -22,16 +22,18 @@ from tensorflow.python.platform import test
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.regularizers import L1L2
from tensorflow.python.keras.initializers import constant
from tensorflow.python.keras import losses
from tensorflow.python.keras.models import Model
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import test_util
from tensorflow.python import ops as _ops
from absl.testing import parameterized
from privacy.bolton.loss import StrongConvexMixin
from privacy.bolton import optimizer as opt
class TestModel(Model):
"""
Bolton episilon-delta model
@ -46,10 +48,10 @@ class TestModel(Model):
Descent-based Analytics by Xi Wu et. al.
"""
def __init__(self, n_classes=2):
def __init__(self, n_outputs=2, input_shape=(16,), init_value=2):
"""
Args:
n_classes: number of output classes to predict.
n_outputs: number of output neurons
epsilon: level of privacy guarantee
noise_distribution: distribution to pull weight perturbations from
weights_initializer: initializer for weights
@ -57,13 +59,13 @@ class TestModel(Model):
dtype: data type to use for tensors
"""
super(TestModel, self).__init__(name='bolton', dynamic=False)
self.n_classes = n_classes
self.layer_input_shape = (16, 1)
self.n_outputs = n_outputs
self.layer_input_shape = input_shape
self.output_layer = tf.keras.layers.Dense(
self.n_classes,
input_shape=self.layer_input_shape,
kernel_regularizer=L1L2(l2=1),
kernel_initializer='glorot_uniform',
self.n_outputs,
input_shape=self.layer_input_shape,
kernel_regularizer=L1L2(l2=1),
kernel_initializer=constant(init_value),
)
@ -84,7 +86,7 @@ class TestLoss(losses.Loss, StrongConvexMixin):
def __init__(self, reg_lambda, C, radius_constant, name='test'):
super(TestLoss, self).__init__(name=name)
self.reg_lambda = reg_lambda
self.C = C
self.C = C # pylint: disable=invalid-name
self.radius_constant = radius_constant
def radius(self):
@ -93,7 +95,7 @@ class TestLoss(losses.Loss, StrongConvexMixin):
Returns: radius
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
return _ops.convert_to_tensor_v2(self.radius_constant, dtype=tf.float32)
def gamma(self):
""" Gamma strongly convex
@ -125,13 +127,17 @@ class TestLoss(losses.Loss, StrongConvexMixin):
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def call(self, val0, val1):
def call(self, y_true, y_pred):
"""Loss function that is minimized at the mean of the input points."""
return 0.5 * tf.reduce_sum(tf.math.squared_difference(val0, val1), axis=1)
return 0.5 * tf.reduce_sum(
tf.math.squared_difference(y_true, y_pred),
axis=1
)
def max_class_weight(self, class_weight):
def max_class_weight(self, class_weight, dtype=tf.float32):
if class_weight is None:
return 1
raise NotImplementedError('')
def kernel_regularizer(self):
return L1L2(l2=self.reg_lambda)
@ -182,18 +188,6 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
"""Bolton Optimizer tests"""
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': 'branch beta',
'fn': 'limit_learning_rate',
'args': [tf.Variable(2, dtype=tf.float32),
tf.Variable(1, dtype=tf.float32)],
'result': tf.Variable(0.5, dtype=tf.float32),
'test_attr': 'learning_rate'},
{'testcase_name': 'branch gamma',
'fn': 'limit_learning_rate',
'args': [tf.Variable(1, dtype=tf.float32),
tf.Variable(1, dtype=tf.float32)],
'result': tf.Variable(1, dtype=tf.float32),
'test_attr': 'learning_rate'},
{'testcase_name': 'getattr',
'fn': '__getattr__',
'args': ['dtype'],
@ -202,8 +196,8 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
{'testcase_name': 'project_weights_to_r',
'fn': 'project_weights_to_r',
'args': ['dtype'],
'result': tf.float32,
'test_attr': None},
'result': None,
'test_attr': ''},
])
def test_fn(self, fn, args, result, test_attr):
"""test that a fn of Bolton optimizer is working as expected.
@ -218,15 +212,176 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
"""
tf.random.set_seed(1)
loss = TestLoss(1, 1, 1)
private = opt.Bolton(TestOptimizer(), loss)
res = getattr(private, fn, None)(*args)
bolton = opt.Bolton(TestOptimizer(), loss)
model = TestModel(1)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
bolton._is_init = True
bolton.layers = model.layers
bolton.epsilon = 2
bolton.noise_distribution = 'laplace'
bolton.n_outputs = 1
bolton.n_samples = 1
res = getattr(bolton, fn, None)(*args)
if test_attr is not None:
res = getattr(private, test_attr, None)
res = getattr(bolton, test_attr, None)
if hasattr(res, 'numpy') and hasattr(result, 'numpy'): # both tensors/not
res = res.numpy()
result = result.numpy()
self.assertEqual(res, result)
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': '1 value project to r=1',
'r': 1,
'init_value': 2,
'shape': (1,),
'n_out': 1,
'result': [[1]]},
{'testcase_name': '2 value project to r=1',
'r': 1,
'init_value': 2,
'shape': (2,),
'n_out': 1,
'result': [[0.707107], [0.707107]]},
{'testcase_name': '1 value project to r=2',
'r': 2,
'init_value': 3,
'shape': (1,),
'n_out': 1,
'result': [[2]]},
{'testcase_name': 'no project',
'r': 2,
'init_value': 1,
'shape': (1,),
'n_out': 1,
'result': [[1]]},
])
def test_project(self, r, shape, n_out, init_value, result):
"""test that a fn of Bolton optimizer is working as expected.
Args:
fn: method of Optimizer to test
args: args to optimizer fn
result: the expected result
test_attr: None if the fn returns the test result. Otherwise, this is
the attribute of Bolton to check against result with.
"""
tf.random.set_seed(1)
@tf.function
def project_fn(r):
loss = TestLoss(1, 1, r)
bolton = opt.Bolton(TestOptimizer(), loss)
model = TestModel(n_out, shape, init_value)
model.compile(bolton, loss)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
bolton._is_init = True
bolton.layers = model.layers
bolton.epsilon = 2
bolton.noise_distribution = 'laplace'
bolton.n_outputs = 1
bolton.n_samples = 1
bolton.project_weights_to_r()
return _ops.convert_to_tensor_v2(bolton.layers[0].kernel, tf.float32)
res = project_fn(r)
self.assertAllClose(res, result)
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': 'normal values',
'epsilon': 2,
'noise': 'laplace',
'class_weights': 1},
])
def test_context_manager(self, noise, epsilon, class_weights):
"""Tests the context manager functionality of the optimizer
Args:
noise: noise distribution to pick
epsilon: epsilon privacy parameter to use
class_weights: class_weights to use
"""
@tf.function
def test_run():
loss = TestLoss(1, 1, 1)
bolton = opt.Bolton(TestOptimizer(), loss)
model = TestModel(1, (1,), 1)
model.compile(bolton, loss)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
with bolton(noise, epsilon, model.layers, class_weights, 1, 1, 1) as _:
pass
return _ops.convert_to_tensor_v2(bolton.epsilon, dtype=tf.float32)
epsilon = test_run()
self.assertEqual(epsilon.numpy(), -1)
@parameterized.named_parameters([
{'testcase_name': 'invalid noise',
'epsilon': 1,
'noise': 'not_valid',
'err_msg': 'Detected noise distribution: not_valid not one of:'},
{'testcase_name': 'invalid epsilon',
'epsilon': -1,
'noise': 'laplace',
'err_msg': 'Detected epsilon: -1. Valid range is 0 < epsilon <inf'},
])
def test_context_domains(self, noise, epsilon, err_msg):
"""
Args:
noise: noise distribution to pick
epsilon: epsilon privacy parameter to use
err_msg: the expected error message
"""
@tf.function
def test_run(noise, epsilon):
loss = TestLoss(1, 1, 1)
bolton = opt.Bolton(TestOptimizer(), loss)
model = TestModel(1, (1,), 1)
model.compile(bolton, loss)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
with bolton(noise, epsilon, model.layers, 1, 1, 1, 1) as _:
pass
with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method
test_run(noise, epsilon)
@parameterized.named_parameters([
{'testcase_name': 'fn: get_noise',
'fn': 'get_noise',
'args': [1, 1],
'err_msg': 'ust be called from within the optimizer\'s context'},
])
def test_not_in_context(self, fn, args, err_msg):
"""Tests that the expected functions raise errors when not in context.
Args:
fn: the function to test
args: the arguments for said function
err_msg: expected error message
"""
@tf.function
def test_run(fn, args):
loss = TestLoss(1, 1, 1)
bolton = opt.Bolton(TestOptimizer(), loss)
model = TestModel(1, (1,), 1)
model.compile(bolton, loss)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
getattr(bolton, fn)(*args)
with self.assertRaisesRegexp(Exception, err_msg): # pylint: disable=deprecated-method
test_run(fn, args)
@parameterized.named_parameters([
{'testcase_name': 'fn: get_updates',
'fn': 'get_updates',
@ -267,27 +422,33 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
"""
loss = TestLoss(1, 1, 1)
optimizer = TestOptimizer()
optimizer = opt.Bolton(optimizer, loss)
model = TestModel(2)
bolton = opt.Bolton(optimizer, loss)
model = TestModel(3)
model.compile(optimizer, loss)
model.layers[0].kernel_initializer(model.layer_input_shape)
print(model.layers[0].__dict__)
with optimizer('laplace', 2, model.layers, 1, 1, model.n_classes):
self.assertEqual(
getattr(optimizer, fn, lambda: 'fn not found')(*args),
'test'
)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
bolton._is_init = True
bolton.layers = model.layers
bolton.epsilon = 2
bolton.noise_distribution = 'laplace'
bolton.n_outputs = 1
bolton.n_samples = 1
self.assertEqual(
getattr(bolton, fn, lambda: 'fn not found')(*args),
'test'
)
@parameterized.named_parameters([
{'testcase_name': 'fn: limit_learning_rate',
'fn': 'limit_learning_rate',
'args': [1, 1, 1]},
{'testcase_name': 'fn: project_weights_to_r',
'fn': 'project_weights_to_r',
'args': []},
{'testcase_name': 'fn: get_noise',
'fn': 'get_noise',
'args': [1, 1, 1, 1]},
'args': [1, 1]},
])
def test_not_reroute_fn(self, fn, args):
"""Test that a fn that should not be rerouted to the internal optimizer is
@ -297,11 +458,30 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
fn: fn to test
args: arguments to that fn
"""
optimizer = TestOptimizer()
loss = TestLoss(1, 1, 1)
optimizer = opt.Bolton(optimizer, loss)
self.assertNotEqual(getattr(optimizer, fn, lambda: 'test')(*args),
'test')
@tf.function
def test_run(fn, args):
loss = TestLoss(1, 1, 1)
bolton = opt.Bolton(TestOptimizer(), loss)
model = TestModel(1, (1,), 1)
model.compile(bolton, loss)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
bolton._is_init = True
bolton.noise_distribution = 'laplace'
bolton.epsilon = 1
bolton.layers = model.layers
bolton.class_weights = 1
bolton.n_samples = 1
bolton.batch_size = 1
bolton.n_outputs = 1
res = getattr(bolton, fn, lambda: 'test')(*args)
if res != 'test':
res = 1
else:
res = 0
return _ops.convert_to_tensor_v2(res, dtype=tf.float32)
self.assertNotEqual(test_run(fn, args), 0)
@parameterized.named_parameters([
{'testcase_name': 'attr: _iterations',
@ -323,8 +503,8 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
)
@parameterized.named_parameters([
{'testcase_name': 'attr does not exist',
'attr': '_not_valid'}
{'testcase_name': 'attr does not exist',
'attr': '_not_valid'}
])
def test_attribute_error(self, attr):
""" test that attribute of internal optimizer is correctly rerouted to
@ -340,5 +520,54 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
with self.assertRaises(AttributeError):
getattr(optimizer, attr)
class SchedulerTest(keras_parameterized.TestCase):
"""GammaBeta Scheduler tests"""
@parameterized.named_parameters([
{'testcase_name': 'not in context',
'err_msg': 'Please initialize the GammaBetaDecreasingStep Learning Rate'
' Scheduler'
}
])
def test_bad_call(self, err_msg):
""" test that attribute of internal optimizer is correctly rerouted to
the internal optimizer
Args:
attr: attribute to test
result: result after checking attribute
"""
scheduler = opt.GammaBetaDecreasingStep()
with self.assertRaisesRegexp(Exception, err_msg): # pylint: disable=deprecated-method
scheduler(1)
@parameterized.named_parameters([
{'testcase_name': 'step 1',
'step': 1,
'res': 0.5},
{'testcase_name': 'step 2',
'step': 2,
'res': 0.5},
{'testcase_name': 'step 3',
'step': 3,
'res': 0.333333333},
])
def test_call(self, step, res):
""" test that attribute of internal optimizer is correctly rerouted to
the internal optimizer
Args:
attr: attribute to test
result: result after checking attribute
"""
beta = _ops.convert_to_tensor_v2(2, dtype=tf.float32)
gamma = _ops.convert_to_tensor_v2(1, dtype=tf.float32)
scheduler = opt.GammaBetaDecreasingStep()
scheduler.initialize(beta, gamma)
step = _ops.convert_to_tensor_v2(step, dtype=tf.float32)
lr = scheduler(step)
self.assertAllClose(lr.numpy(), res)
if __name__ == '__main__':
test.main()