Working bolton model without unit tests.

-- moving to Bolton Optimizer
Model is now just a convenient wrapper and example for users.
Optimizer holds ALL Bolton privacy requirements.
Optimizer is used as a context manager, and must be passed the model's layers.
Unit tests incomplete, committing for visibility into the design.
This commit is contained in:
Christopher Choquette Choo 2019-06-13 01:01:31 -04:00
parent 751eaead54
commit ec18db5ec5
5 changed files with 913 additions and 345 deletions

View file

@ -18,23 +18,17 @@ from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.platform import test
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.keras.optimizer_v2 import adagrad
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras import losses
from tensorflow.python.framework import test_util
from privacy.bolton import model
from tensorflow.python.keras.regularizers import L1L2
from absl.testing import parameterized
from privacy.bolton.loss import StrongConvexBinaryCrossentropy
from privacy.bolton.loss import StrongConvexHuber
from privacy.bolton.loss import StrongConvexMixin
from absl.testing import parameterized
from absl.testing import absltest
from tensorflow.python.keras.regularizers import L1L2
class StrongConvexTests(keras_parameterized.TestCase):
class StrongConvexMixinTests(keras_parameterized.TestCase):
"""Tests for the StrongConvexMixin"""
@parameterized.named_parameters([
{'testcase_name': 'beta not implemented',
'fn': 'beta',
@ -50,6 +44,12 @@ class StrongConvexTests(keras_parameterized.TestCase):
'args': []},
])
def test_not_implemented(self, fn, args):
"""Test that the given fn's are not implemented on the mixin.
Args:
fn: fn on Mixin to test
args: arguments to fn of Mixin
"""
with self.assertRaises(NotImplementedError):
loss = StrongConvexMixin()
getattr(loss, fn, None)(*args)
@ -60,6 +60,12 @@ class StrongConvexTests(keras_parameterized.TestCase):
'args': []},
])
def test_return_none(self, fn, args):
"""Test that fn of Mixin returns None
Args:
fn: fn of Mixin to test
args: arguments to fn of Mixin
"""
loss = StrongConvexMixin()
ret = getattr(loss, fn, None)(*args)
self.assertEqual(ret, None)
@ -71,44 +77,56 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
@parameterized.named_parameters([
{'testcase_name': 'normal',
'reg_lambda': 1,
'c': 1,
'C': 1,
'radius_constant': 1
},
])
def test_init_params(self, reg_lambda, c, radius_constant):
def test_init_params(self, reg_lambda, C, radius_constant):
"""Test initialization for given arguments
Args:
reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg
radius_constant: initialization value for radius_constant arg
"""
# test valid domains for each variable
loss = StrongConvexBinaryCrossentropy(reg_lambda, c, radius_constant)
loss = StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
self.assertIsInstance(loss, StrongConvexBinaryCrossentropy)
@parameterized.named_parameters([
{'testcase_name': 'negative c',
'reg_lambda': 1,
'c': -1,
'C': -1,
'radius_constant': 1
},
{'testcase_name': 'negative radius',
'reg_lambda': 1,
'c': 1,
'C': 1,
'radius_constant': -1
},
{'testcase_name': 'negative lambda',
'reg_lambda': -1,
'c': 1,
'C': 1,
'radius_constant': 1
},
])
def test_bad_init_params(self, reg_lambda, c, radius_constant):
def test_bad_init_params(self, reg_lambda, C, radius_constant):
"""Test invalid domain for given params. Should return ValueError
Args:
reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg
radius_constant: initialization value for radius_constant arg
"""
# test valid domains for each variable
with self.assertRaises(ValueError):
loss = StrongConvexBinaryCrossentropy(reg_lambda, c, radius_constant)
StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
# [] for compatibility with tensorflow loss calculation
{'testcase_name': 'both positive',
'logits': [10000],
'y_true': [1],
'result': 0,
'logits': [10000],
'y_true': [1],
'result': 0,
},
{'testcase_name': 'positive gradient negative logits',
'logits': [-10000],
@ -127,6 +145,12 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
},
])
def test_calculation(self, logits, y_true, result):
"""Test the call method to ensure it returns the correct value
Args:
logits: unscaled output of model
y_true: label
result: correct loss calculation value
"""
logits = tf.Variable(logits, False, dtype=tf.float32)
y_true = tf.Variable(y_true, False, dtype=tf.float32)
loss = StrongConvexBinaryCrossentropy(0.00001, 1, 1)
@ -160,6 +184,13 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
},
])
def test_fns(self, init_args, fn, args, result):
"""Test that fn of BinaryCrossentropy loss returns the correct result
Args:
init_args: init values for loss instance
fn: the fn to test
args: the arguments to above function
result: the correct result from the fn
"""
loss = StrongConvexBinaryCrossentropy(*init_args)
expected = getattr(loss, fn, lambda: 'fn not found')(*args)
if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor
@ -183,6 +214,12 @@ class HuberTests(keras_parameterized.TestCase):
},
])
def test_init_params(self, reg_lambda, c, radius_constant, delta):
"""Test initialization for given arguments
Args:
reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg
radius_constant: initialization value for radius_constant arg
"""
# test valid domains for each variable
loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta)
self.assertIsInstance(loss, StrongConvexHuber)
@ -214,18 +251,24 @@ class HuberTests(keras_parameterized.TestCase):
},
])
def test_bad_init_params(self, reg_lambda, c, radius_constant, delta):
"""Test invalid domain for given params. Should return ValueError
Args:
reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg
radius_constant: initialization value for radius_constant arg
"""
# test valid domains for each variable
with self.assertRaises(ValueError):
loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta)
StrongConvexHuber(reg_lambda, c, radius_constant, delta)
# test the bounds and test varied delta's
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': 'delta=1,y_true=1 z>1+h decision boundary',
'logits': 2.1,
'y_true': 1,
'delta': 1,
'result': 0,
'logits': 2.1,
'y_true': 1,
'delta': 1,
'result': 0,
},
{'testcase_name': 'delta=1,y_true=1 z<1+h decision boundary',
'logits': 1.9,
@ -277,6 +320,12 @@ class HuberTests(keras_parameterized.TestCase):
},
])
def test_calculation(self, logits, y_true, delta, result):
"""Test the call method to ensure it returns the correct value
Args:
logits: unscaled output of model
y_true: label
result: correct loss calculation value
"""
logits = tf.Variable(logits, False, dtype=tf.float32)
y_true = tf.Variable(y_true, False, dtype=tf.float32)
loss = StrongConvexHuber(0.00001, 1, 1, delta)
@ -310,6 +359,13 @@ class HuberTests(keras_parameterized.TestCase):
},
])
def test_fns(self, init_args, fn, args, result):
"""Test that fn of BinaryCrossentropy loss returns the correct result
Args:
init_args: init values for loss instance
fn: the fn to test
args: the arguments to above function
result: the correct result from the fn
"""
loss = StrongConvexHuber(*init_args)
expected = getattr(loss, fn, lambda: 'fn not found')(*args)
if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor
@ -322,4 +378,4 @@ class HuberTests(keras_parameterized.TestCase):
if __name__ == '__main__':
tf.test.main()
tf.test.main()

View file

@ -21,12 +21,12 @@ from tensorflow.python.keras.models import Model
from tensorflow.python.keras import optimizers
from tensorflow.python.framework import ops as _ops
from privacy.bolton.loss import StrongConvexMixin
from privacy.bolton.optimizer import Private
from privacy.bolton.optimizer import Bolton
_accepted_distributions = ['laplace']
class Bolton(Model):
class BoltonModel(Model):
"""
Bolton episilon-delta model
Uses 4 key steps to achieve privacy guarantees:
@ -42,8 +42,7 @@ class Bolton(Model):
def __init__(self,
n_classes,
epsilon,
noise_distribution='laplace',
# noise_distribution='laplace',
seed=1,
dtype=tf.float32
):
@ -58,47 +57,22 @@ class Bolton(Model):
dtype: data type to use for tensors
"""
class MyCustomCallback(tf.keras.callbacks.Callback):
"""Custom callback for bolton training requirements.
Implements steps (see Bolton class):
2. Projects weights to R after each batch
3. Limits learning rate
"""
def on_train_batch_end(self, batch, logs=None):
loss = self.model.loss
self.model.optimizer.limit_learning_rate(
self.model.run_eagerly,
loss.beta(self.model.class_weight),
loss.gamma()
)
self.model._project_weights_to_r(loss.radius(), False)
def on_train_end(self, logs=None):
loss = self.model.loss
self.model._project_weights_to_r(loss.radius(), True)
if epsilon <= 0:
raise ValueError('Detected epsilon: {0}. '
'Valid range is 0 < epsilon <inf'.format(epsilon))
if noise_distribution not in _accepted_distributions:
raise ValueError('Detected noise distribution: {0} not one of: {1} valid'
'distributions'.format(noise_distribution,
_accepted_distributions))
super(Bolton, self).__init__(name='bolton', dynamic=False)
# if noise_distribution not in _accepted_distributions:
# raise ValueError('Detected noise distribution: {0} not one of: {1} valid'
# 'distributions'.format(noise_distribution,
# _accepted_distributions))
# if epsilon <= 0:
# raise ValueError('Detected epsilon: {0}. '
# 'Valid range is 0 < epsilon <inf'.format(epsilon))
# self.epsilon = epsilon
super(BoltonModel, self).__init__(name='bolton', dynamic=False)
self.n_classes = n_classes
# if we do regularization here, we require the user to re-instantiate
# the model each time they want to
# change lambda, unless we standardize modifying it later at .compile
self.force = False
self.noise_distribution = noise_distribution
self.epsilon = epsilon
# self.noise_distribution = noise_distribution
self.seed = seed
self.__in_fit = False
self._layers_instantiated = False
self._callback = MyCustomCallback()
# self._callback = MyCustomCallback()
self._dtype = dtype
def call(self, inputs):
@ -139,55 +113,65 @@ class Bolton(Model):
kernel_regularizer=loss.kernel_regularizer(),
kernel_initializer=kernel_intiializer(),
)
# if we don't do regularization here, we require the user to
# re-instantiate the model each time they want to change the penalty
# weighting
self._layers_instantiated = True
self.output_layer.kernel_regularizer.l2 = loss.reg_lambda
if not isinstance(optimizer, Private):
if not isinstance(optimizer, Bolton):
optimizer = optimizers.get(optimizer)
optimizer = Private(optimizer)
optimizer = Bolton(optimizer, loss)
super(Bolton, self).compile(optimizer,
loss=loss,
metrics=metrics,
loss_weights=loss_weights,
sample_weight_mode=sample_weight_mode,
weighted_metrics=weighted_metrics,
target_tensors=target_tensors,
distribute=distribute,
**kwargs
)
super(BoltonModel, self).compile(optimizer,
loss=loss,
metrics=metrics,
loss_weights=loss_weights,
sample_weight_mode=sample_weight_mode,
weighted_metrics=weighted_metrics,
target_tensors=target_tensors,
distribute=distribute,
**kwargs
)
def _post_fit(self, x, n_samples):
"""Implements 1-time weight changes needed for Bolton method.
In this case, specifically implements the noise addition
assuming a strongly convex function.
Args:
x: inputs
n_samples: number of samples in the inputs. In case the number
cannot be readily determined by inspecting x.
Returns:
"""
data_size = None
if n_samples is not None:
data_size = n_samples
elif hasattr(x, 'shape'):
data_size = x.shape[0]
elif hasattr(x, "__len__"):
data_size = len(x)
elif data_size is None:
if n_samples is None:
raise ValueError("Unable to detect the number of training "
"samples and n_smaples was None. "
"either pass a dataset with a .shape or "
"__len__ attribute or explicitly pass the "
"number of samples as n_smaples.")
for layer in self._layers:
layer.kernel = layer.kernel + self._get_noise(
self.noise_distribution,
data_size
)
# def _post_fit(self, x, n_samples):
# """Implements 1-time weight changes needed for Bolton method.
# In this case, specifically implements the noise addition
# assuming a strongly convex function.
#
# Args:
# x: inputs
# n_samples: number of samples in the inputs. In case the number
# cannot be readily determined by inspecting x.
#
# Returns:
#
# """
# data_size = None
# if n_samples is not None:
# data_size = n_samples
# elif hasattr(x, 'shape'):
# data_size = x.shape[0]
# elif hasattr(x, "__len__"):
# data_size = len(x)
# elif data_size is None:
# if n_samples is None:
# raise ValueError("Unable to detect the number of training "
# "samples and n_smaples was None. "
# "either pass a dataset with a .shape or "
# "__len__ attribute or explicitly pass the "
# "number of samples as n_smaples.")
# for layer in self.layers:
# # layer.kernel = layer.kernel + self._get_noise(
# # data_size
# # )
# input_dim = layer.kernel.numpy().shape[0]
# layer.kernel = layer.kernel + self.optimizer.get_noise(
# self.loss,
# data_size,
# input_dim,
# self.n_classes,
# self.class_weight
# )
def fit(self,
x=None,
@ -209,6 +193,8 @@ class Bolton(Model):
workers=1,
use_multiprocessing=False,
n_samples=None,
epsilon=2,
noise_distribution='laplace',
**kwargs):
"""Reroutes to super fit with additional Bolton delta-epsilon privacy
requirements implemented. Note, inputs must be normalized s.t. ||x|| < 1
@ -226,35 +212,40 @@ class Bolton(Model):
"""
self.__in_fit = True
cb = [self._callback]
if callbacks is not None:
cb.extend(callbacks)
callbacks = cb
# cb = [self.optimizer.callbacks]
# if callbacks is not None:
# cb.extend(callbacks)
# callbacks = cb
if class_weight is None:
class_weight = self.calculate_class_weights(class_weight)
self.class_weight = class_weight
out = super(Bolton, self).fit(x=x,
y=y,
batch_size=batch_size,
epochs=epochs,
verbose=verbose,
callbacks=callbacks,
validation_split=validation_split,
validation_data=validation_data,
shuffle=shuffle,
class_weight=class_weight,
sample_weight=sample_weight,
initial_epoch=initial_epoch,
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps,
validation_freq=validation_freq,
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
**kwargs
)
self._post_fit(x, n_samples)
self.__in_fit = False
# self.class_weight = class_weight
with self.optimizer(noise_distribution,
epsilon,
self.layers,
class_weight,
n_samples,
self.n_classes,
) as optim:
out = super(BoltonModel, self).fit(x=x,
y=y,
batch_size=batch_size,
epochs=epochs,
verbose=verbose,
callbacks=callbacks,
validation_split=validation_split,
validation_data=validation_data,
shuffle=shuffle,
class_weight=class_weight,
sample_weight=sample_weight,
initial_epoch=initial_epoch,
steps_per_epoch=steps_per_epoch,
validation_steps=validation_steps,
validation_freq=validation_freq,
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
**kwargs
)
return out
def fit_generator(self,
@ -284,7 +275,7 @@ class Bolton(Model):
if class_weight is None:
class_weight = self.calculate_class_weights(class_weight)
self.class_weight = class_weight
out = super(Bolton, self).fit_generator(
out = super(BoltonModel, self).fit_generator(
generator,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
@ -366,66 +357,195 @@ class Bolton(Model):
"1D array".format(class_weights.shape))
if class_weights.shape[0] != num_classes:
raise ValueError(
"Detected array length: {0} instead of: {1}".format(
class_weights.shape[0],
num_classes
)
"Detected array length: {0} instead of: {1}".format(
class_weights.shape[0],
num_classes
)
)
return class_weights
def _project_weights_to_r(self, r, force=False):
"""helper method to normalize the weights to the R-ball.
# def _project_weights_to_r(self, r, force=False):
# """helper method to normalize the weights to the R-ball.
#
# Args:
# r: radius of "R-Ball". Scalar to normalize to.
# force: True to normalize regardless of previous weight values.
# False to check if weights > R-ball and only normalize then.
#
# Returns:
#
# """
# for layer in self.layers:
# weight_norm = tf.norm(layer.kernel, axis=0)
# if force:
# layer.kernel = layer.kernel / (weight_norm / r)
# elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self._dtype)) > 0:
# layer.kernel = layer.kernel / (weight_norm / r)
Args:
r: radius of "R-Ball". Scalar to normalize to.
force: True to normalize regardless of previous weight values.
False to check if weights > R-ball and only normalize then.
# def _get_noise(self, distribution, data_size):
# """Sample noise to be added to weights for privacy guarantee
#
# Args:
# distribution: the distribution type to pull noise from
# data_size: the number of samples
#
# Returns: noise in shape of layer's weights to be added to the weights.
#
# """
# distribution = distribution.lower()
# input_dim = self.layers[0].kernel.numpy().shape[0]
# loss = self.loss
# if distribution == _accepted_distributions[0]: # laplace
# per_class_epsilon = self.epsilon / (self.n_classes)
# l2_sensitivity = (2 *
# loss.lipchitz_constant(self.class_weight)) / \
# (loss.gamma() * data_size)
# unit_vector = tf.random.normal(shape=(input_dim, self.n_classes),
# mean=0,
# seed=1,
# stddev=1.0,
# dtype=self._dtype)
# unit_vector = unit_vector / tf.math.sqrt(
# tf.reduce_sum(tf.math.square(unit_vector), axis=0)
# )
#
# beta = l2_sensitivity / per_class_epsilon
# alpha = input_dim # input_dim
# gamma = tf.random.gamma([self.n_classes],
# alpha,
# beta=1 / beta,
# seed=1,
# dtype=self._dtype
# )
# return unit_vector * gamma
# raise NotImplementedError('Noise distribution: {0} is not '
# 'a valid distribution'.format(distribution))
Returns:
"""
for layer in self._layers:
weight_norm = tf.norm(layer.kernel, axis=0)
if force:
layer.kernel = layer.kernel / (weight_norm / r)
elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self._dtype)) > 0:
layer.kernel = layer.kernel / (weight_norm / r)
if __name__ == '__main__':
import tensorflow as tf
def _get_noise(self, distribution, data_size):
"""Sample noise to be added to weights for privacy guarantee
import os
import time
import matplotlib.pyplot as plt
Args:
distribution: the distribution type to pull noise from
data_size: the number of samples
_URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz'
Returns: noise in shape of layer's weights to be added to the weights.
path_to_zip = tf.keras.utils.get_file('facades.tar.gz',
origin=_URL,
extract=True)
"""
distribution = distribution.lower()
input_dim = self._layers[0].kernel.numpy().shape[0]
loss = self.loss
if distribution == _accepted_distributions[0]: # laplace
per_class_epsilon = self.epsilon / (self.n_classes)
l2_sensitivity = (2 *
loss.lipchitz_constant(self.class_weight)) / \
(loss.gamma() * data_size)
unit_vector = tf.random.normal(shape=(input_dim, self.n_classes),
mean=0,
seed=1,
stddev=1.0,
dtype=self._dtype)
unit_vector = unit_vector / tf.math.sqrt(
tf.reduce_sum(tf.math.square(unit_vector), axis=0)
)
PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')
BUFFER_SIZE = 400
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
beta = l2_sensitivity / per_class_epsilon
alpha = input_dim # input_dim
gamma = tf.random.gamma([self.n_classes],
alpha,
beta=1 / beta,
seed=1,
dtype=self._dtype
)
return unit_vector * gamma
raise NotImplementedError('Noise distribution: {0} is not '
'a valid distribution'.format(distribution))
def load(image_file):
image = tf.io.read_file(image_file)
image = tf.image.decode_jpeg(image)
w = tf.shape(image)[1]
w = w // 2
real_image = image[:, :w, :]
input_image = image[:, w:, :]
input_image = tf.cast(input_image, tf.float32)
real_image = tf.cast(real_image, tf.float32)
return input_image, real_image
inp, re = load(PATH + 'train/100.jpg')
# casting to int for matplotlib to show the image
plt.figure()
plt.imshow(inp / 255.0)
plt.figure()
plt.imshow(re / 255.0)
def resize(input_image, real_image, height, width):
input_image = tf.image.resize(input_image, [height, width],
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
real_image = tf.image.resize(real_image, [height, width],
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
return input_image, real_image
def random_crop(input_image, real_image):
stacked_image = tf.stack([input_image, real_image], axis=0)
cropped_image = tf.image.random_crop(
stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
return cropped_image[0], cropped_image[1]
def normalize(input_image, real_image):
input_image = (input_image / 127.5) - 1
real_image = (real_image / 127.5) - 1
return input_image, real_image
@tf.function()
def random_jitter(input_image, real_image):
# resizing to 286 x 286 x 3
input_image, real_image = resize(input_image, real_image, 286, 286)
# randomly cropping to 256 x 256 x 3
input_image, real_image = random_crop(input_image, real_image)
if tf.random.uniform(()) > 0.5:
# random mirroring
input_image = tf.image.flip_left_right(input_image)
real_image = tf.image.flip_left_right(real_image)
return input_image, real_image
def load_image_train(image_file):
input_image, real_image = load(image_file)
input_image, real_image = random_jitter(input_image, real_image)
input_image, real_image = normalize(input_image, real_image)
return input_image, real_image
def load_image_test(image_file):
input_image, real_image = load(image_file)
input_image, real_image = resize(input_image, real_image,
IMG_HEIGHT, IMG_WIDTH)
input_image, real_image = normalize(input_image, real_image)
return input_image, real_image
train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg')
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.map(load_image_train,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.batch(1)
# steps_per_epoch = training_utils.infer_steps_for_dataset(
# train_dataset, None, epochs=1, steps_name='steps')
# for batch in train_dataset:
# print(batch[1].shape)
test_dataset = tf.data.Dataset.list_files(PATH + 'test/*.jpg')
# shuffling so that for every epoch a different image is generated
# to predict and display the progress of our model.
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(1)
be = BoltonModel(3, 2)
from tensorflow.python.keras.optimizer_v2 import adam
from privacy.bolton import loss
test = adam.Adam()
l = loss.StrongConvexBinaryCrossentropy(1, 2, 1)
be.compile(test, l)
print("Eager exeuction: {0}".format(tf.executing_eagerly()))
be.fit(train_dataset, verbose=0, steps_per_epoch=1, n_samples=1)

View file

@ -32,7 +32,7 @@ from absl.testing import absltest
from tensorflow.python.keras.regularizers import L1L2
class TestLoss(losses.Loss):
class TestLoss(losses.Loss, StrongConvexMixin):
"""Test loss function for testing Bolton model"""
def __init__(self, reg_lambda, C, radius_constant, name='test'):
super(TestLoss, self).__init__(name=name)
@ -145,21 +145,25 @@ class InitTests(keras_parameterized.TestCase):
self.assertIsInstance(clf, model.Bolton)
@parameterized.named_parameters([
{'testcase_name': 'invalid noise',
'n_classes': 1,
'epsilon': 1,
'noise_distribution': 'not_valid',
'weights_initializer': tf.initializers.GlorotUniform(),
},
{'testcase_name': 'invalid epsilon',
'n_classes': 1,
'epsilon': -1,
'noise_distribution': 'laplace',
'weights_initializer': tf.initializers.GlorotUniform(),
},
{'testcase_name': 'invalid noise',
'n_classes': 1,
'epsilon': 1,
'noise_distribution': 'not_valid',
'weights_initializer': tf.initializers.GlorotUniform(),
},
{'testcase_name': 'invalid epsilon',
'n_classes': 1,
'epsilon': -1,
'noise_distribution': 'laplace',
'weights_initializer': tf.initializers.GlorotUniform(),
},
])
def test_bad_init_params(
self, n_classes, epsilon, noise_distribution, weights_initializer):
self,
n_classes,
epsilon,
noise_distribution,
weights_initializer):
# test invalid domains for each variable, especially noise
seed = 1
with self.assertRaises(ValueError):
@ -204,16 +208,16 @@ class InitTests(keras_parameterized.TestCase):
self.assertEqual(clf.loss, loss)
@parameterized.named_parameters([
{'testcase_name': 'Not strong loss',
'n_classes': 1,
'loss': losses.BinaryCrossentropy(),
'optimizer': 'adam',
},
{'testcase_name': 'Not valid optimizer',
'n_classes': 1,
'loss': TestLoss(1, 1, 1),
'optimizer': 'ada',
}
{'testcase_name': 'Not strong loss',
'n_classes': 1,
'loss': losses.BinaryCrossentropy(),
'optimizer': 'adam',
},
{'testcase_name': 'Not valid optimizer',
'n_classes': 1,
'loss': TestLoss(1, 1, 1),
'optimizer': 'ada',
}
])
def test_bad_compile(self, n_classes, loss, optimizer):
# test compilaton of invalid tf.optimizer and non instantiated loss.
@ -250,7 +254,7 @@ def _cat_dataset(n_samples, input_dim, n_classes, t='train', generator=False):
y_stack = []
for i_class in range(n_classes):
x_stack.append(
tf.constant(1*i_class, tf.float32, (n_samples, input_dim))
tf.constant(1*i_class, tf.float32, (n_samples, input_dim))
)
y_stack.append(
tf.constant(i_class, tf.float32, (n_samples, n_classes))
@ -258,7 +262,7 @@ def _cat_dataset(n_samples, input_dim, n_classes, t='train', generator=False):
x_set, y_set = tf.stack(x_stack), tf.stack(y_stack)
if generator:
dataset = tf.data.Dataset.from_tensor_slices(
(x_set, y_set)
(x_set, y_set)
)
return dataset
return x_set, y_set
@ -281,10 +285,10 @@ def _do_fit(n_samples,
clf.compile(optimizer, loss)
if generator:
x = _cat_dataset(
n_samples,
input_dim,
n_classes,
generator=generator
n_samples,
input_dim,
n_classes,
generator=generator
)
y = None
# x = x.batch(batch_size)
@ -315,26 +319,26 @@ class FitTests(keras_parameterized.TestCase):
# @test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': 'iterator fit',
'generator': False,
'reset_n_samples': True,
'callbacks': None
},
{'testcase_name': 'iterator fit no samples',
'generator': False,
'reset_n_samples': True,
'callbacks': None
},
{'testcase_name': 'generator fit',
'generator': True,
'reset_n_samples': False,
'callbacks': None
},
{'testcase_name': 'with callbacks',
'generator': True,
'reset_n_samples': False,
'callbacks': TestCallback()
},
{'testcase_name': 'iterator fit',
'generator': False,
'reset_n_samples': True,
'callbacks': None
},
{'testcase_name': 'iterator fit no samples',
'generator': False,
'reset_n_samples': True,
'callbacks': None
},
{'testcase_name': 'generator fit',
'generator': True,
'reset_n_samples': False,
'callbacks': None
},
{'testcase_name': 'with callbacks',
'generator': True,
'reset_n_samples': False,
'callbacks': TestCallback()
},
])
def test_fit(self, generator, reset_n_samples, callbacks):
loss = TestLoss(1, 1, 1)
@ -344,9 +348,19 @@ class FitTests(keras_parameterized.TestCase):
epsilon = 1
batch_size = 1
n_samples = 10
clf = _do_fit(n_samples, input_dim, n_classes, epsilon, generator, batch_size,
reset_n_samples, optimizer, loss, callbacks)
self.assertEqual(hasattr(clf, '_layers'), True)
clf = _do_fit(
n_samples,
input_dim,
n_classes,
epsilon,
generator,
batch_size,
reset_n_samples,
optimizer,
loss,
callbacks
)
self.assertEqual(hasattr(clf, 'layers'), True)
@parameterized.named_parameters([
{'testcase_name': 'generator fit',
@ -368,15 +382,15 @@ class FitTests(keras_parameterized.TestCase):
)
clf.compile(optimizer, loss)
x = _cat_dataset(
n_samples,
input_dim,
n_classes,
generator=generator
n_samples,
input_dim,
n_classes,
generator=generator
)
x = x.batch(batch_size)
x = x.shuffle(n_samples // 2)
clf.fit_generator(x, n_samples=n_samples)
self.assertEqual(hasattr(clf, '_layers'), True)
self.assertEqual(hasattr(clf, 'layers'), True)
@parameterized.named_parameters([
{'testcase_name': 'iterator no n_samples',
@ -399,32 +413,43 @@ class FitTests(keras_parameterized.TestCase):
epsilon = 1
batch_size = 1
n_samples = 10
_do_fit(n_samples, input_dim, n_classes, epsilon, generator, batch_size,
reset_n_samples, optimizer, loss, None, distribution)
_do_fit(
n_samples,
input_dim,
n_classes,
epsilon,
generator,
batch_size,
reset_n_samples,
optimizer,
loss,
None,
distribution
)
@parameterized.named_parameters([
{'testcase_name': 'None class_weights',
'class_weights': None,
'class_counts': None,
'num_classes': None,
'result': 1},
{'testcase_name': 'class weights array',
'class_weights': [1, 1],
'class_counts': [1, 1],
'num_classes': 2,
'result': [1, 1]},
{'testcase_name': 'class weights balanced',
'class_weights': 'balanced',
'class_counts': [1, 1],
'num_classes': 2,
'result': [1, 1]},
{'testcase_name': 'None class_weights',
'class_weights': None,
'class_counts': None,
'num_classes': None,
'result': 1},
{'testcase_name': 'class weights array',
'class_weights': [1, 1],
'class_counts': [1, 1],
'num_classes': 2,
'result': [1, 1]},
{'testcase_name': 'class weights balanced',
'class_weights': 'balanced',
'class_counts': [1, 1],
'num_classes': 2,
'result': [1, 1]},
])
def test_class_calculate(self,
class_weights,
class_counts,
num_classes,
result
):
):
clf = model.Bolton(1, 1)
expected = clf.calculate_class_weights(class_weights,
class_counts,
@ -447,14 +472,14 @@ class FitTests(keras_parameterized.TestCase):
'class_weights': 'balanced',
'class_counts': None,
'num_classes': 1,
'err_msg':
"Class counts must be provided if using class_weights=balanced"},
'err_msg': "Class counts must be provided if "
"using class_weights=balanced"},
{'testcase_name': 'no num classes',
'class_weights': 'balanced',
'class_counts': [1],
'num_classes': None,
'err_msg':
'num_classes must be provided if using class_weights=balanced'},
'err_msg': 'num_classes must be provided if '
'using class_weights=balanced'},
{'testcase_name': 'class counts not array',
'class_weights': 'balanced',
'class_counts': 1,
@ -464,7 +489,7 @@ class FitTests(keras_parameterized.TestCase):
'class_weights': [1],
'class_counts': None,
'num_classes': None,
'err_msg': "You must pass a value for num_classes if"
'err_msg': "You must pass a value for num_classes if "
"creating an array of class_weights"},
{'testcase_name': 'class counts array, improper shape',
'class_weights': [[1], [1]],
@ -481,7 +506,8 @@ class FitTests(keras_parameterized.TestCase):
class_weights,
class_counts,
num_classes,
err_msg):
err_msg
):
clf = model.Bolton(1, 1)
with self.assertRaisesRegexp(ValueError, err_msg):
expected = clf.calculate_class_weights(class_weights,

View file

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Private Optimizer for bolton method"""
"""Bolton Optimizer for bolton method"""
from __future__ import absolute_import
from __future__ import division
@ -19,15 +19,16 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from privacy.bolton.loss import StrongConvexMixin
_private_attributes = ['_internal_optimizer', 'dtype']
_accepted_distributions = ['laplace']
class Private(optimizer_v2.OptimizerV2):
class Bolton(optimizer_v2.OptimizerV2):
"""
Private optimizer wraps another tf optimizer to be used
Bolton optimizer wraps another tf optimizer to be used
as the visible optimizer to the tf model. No matter the optimizer
passed, "Private" enables the bolton model to control the learning rate
passed, "Bolton" enables the bolton model to control the learning rate
based on the strongly convex loss.
For more details on the strong convexity requirements, see:
@ -36,7 +37,8 @@ class Private(optimizer_v2.OptimizerV2):
"""
def __init__(self,
optimizer: optimizer_v2.OptimizerV2,
dtype=tf.float32
loss: StrongConvexMixin,
dtype=tf.float32,
):
"""Constructor.
@ -44,15 +46,100 @@ class Private(optimizer_v2.OptimizerV2):
optimizer: Optimizer_v2 or subclass to be used as the optimizer
(wrapped).
"""
if not isinstance(loss, StrongConvexMixin):
raise ValueError("loss function must be a Strongly Convex and therfore"
"extend the StrongConvexMixin.")
self._private_attributes = ['_internal_optimizer',
'dtype',
'noise_distribution',
'epsilon',
'loss',
'class_weights',
'input_dim',
'n_samples',
'n_classes',
'layers',
'_model'
]
self._internal_optimizer = optimizer
self.dtype = dtype
self.loss = loss
def get_config(self):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
return self._internal_optimizer.get_config()
def limit_learning_rate(self, is_eager, beta, gamma):
def project_weights_to_r(self, force=False):
"""helper method to normalize the weights to the R-ball.
Args:
r: radius of "R-Ball". Scalar to normalize to.
force: True to normalize regardless of previous weight values.
False to check if weights > R-ball and only normalize then.
Returns:
"""
r = self.loss.radius()
for layer in self.layers:
if tf.executing_eagerly():
weight_norm = tf.norm(layer.kernel, axis=0)
if force:
layer.kernel = layer.kernel / (weight_norm / r)
elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self.dtype)) > 0:
layer.kernel = layer.kernel / (weight_norm / r)
else:
weight_norm = tf.norm(layer.kernel, axis=0)
if force:
layer.kernel = layer.kernel / (weight_norm / r)
else:
layer.kernel = tf.cond(
tf.reduce_sum(tf.cast(weight_norm > r, dtype=self.dtype)) > 0,
lambda: layer.kernel / (weight_norm / r),
lambda: layer.kernel
)
def get_noise(self, data_size, input_dim, output_dim, class_weight):
"""Sample noise to be added to weights for privacy guarantee
Args:
distribution: the distribution type to pull noise from
data_size: the number of samples
Returns: noise in shape of layer's weights to be added to the weights.
"""
loss = self.loss
distribution = self.noise_distribution.lower()
if distribution == _accepted_distributions[0]: # laplace
per_class_epsilon = self.epsilon / (output_dim)
l2_sensitivity = (2 *
loss.lipchitz_constant(class_weight)) / \
(loss.gamma() * data_size)
unit_vector = tf.random.normal(shape=(input_dim, output_dim),
mean=0,
seed=1,
stddev=1.0,
dtype=self.dtype)
unit_vector = unit_vector / tf.math.sqrt(
tf.reduce_sum(tf.math.square(unit_vector), axis=0)
)
beta = l2_sensitivity / per_class_epsilon
alpha = input_dim # input_dim
gamma = tf.random.gamma([output_dim],
alpha,
beta=1 / beta,
seed=1,
dtype=self.dtype
)
return unit_vector * gamma
raise NotImplementedError('Noise distribution: {0} is not '
'a valid distribution'.format(distribution))
def limit_learning_rate(self, beta, gamma):
"""Implements learning rate limitation that is required by the bolton
method for sensitivity bounding of the strongly convex function.
Sets the learning rate to the min(1/beta, 1/(gamma*t))
@ -65,20 +152,13 @@ class Private(optimizer_v2.OptimizerV2):
Returns: None
"""
numerator = tf.Variable(initial_value=1, dtype=self.dtype)
numerator = tf.constant(1, dtype=self.dtype)
t = tf.cast(self._iterations, self.dtype)
# will exist on the internal optimizer
pred = numerator / beta < numerator / (gamma * t)
if is_eager: # check eagerly
if pred:
self.learning_rate = numerator / beta
else:
self.learning_rate = numerator / (gamma * t)
if numerator / beta < numerator / (gamma * t):
self.learning_rate = numerator / beta
else:
if pred:
self.learning_rate = numerator / beta
else:
self.learning_rate = numerator / (gamma * t)
self.learning_rate = numerator / (gamma * t)
def from_config(self, *args, **kwargs):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
@ -92,14 +172,25 @@ class Private(optimizer_v2.OptimizerV2):
Args:
name:
Returns: attribute from Private if specified to come from self, else
Returns: attribute from Bolton if specified to come from self, else
from _internal_optimizer.
"""
if name in _private_attributes:
if name == '_private_attributes':
return getattr(self, name)
elif name in self._private_attributes:
return getattr(self, name)
optim = object.__getattribute__(self, '_internal_optimizer')
return object.__getattribute__(optim, name)
try:
return object.__getattribute__(optim, name)
except AttributeError:
raise AttributeError("Neither '{0}' nor '{1}' object has attribute '{2}'"
"".format(
self.__class__.__name__,
self._internal_optimizer.__class__.__name__,
name
)
)
def __setattr__(self, key, value):
""" Set attribute to self instance if its the internal optimizer.
@ -112,7 +203,9 @@ class Private(optimizer_v2.OptimizerV2):
Returns:
"""
if key in _private_attributes:
if key == '_private_attributes':
object.__setattr__(self, key, value)
elif key in key in self._private_attributes:
object.__setattr__(self, key, value)
else:
setattr(self._internal_optimizer, key, value)
@ -130,24 +223,135 @@ class Private(optimizer_v2.OptimizerV2):
def get_updates(self, loss, params):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
return self._internal_optimizer.get_updates(loss, params)
# self.layers = params
out = self._internal_optimizer.get_updates(loss, params)
self.limit_learning_rate(self.loss.beta(self.class_weights),
self.loss.gamma()
)
self.project_weights_to_r()
return out
def apply_gradients(self, *args, **kwargs):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
return self._internal_optimizer.apply_gradients(*args, **kwargs)
# grads_and_vars = kwargs.get('grads_and_vars', None)
# grads_and_vars = optimizer_v2._filter_grads(grads_and_vars)
# var_list = [v for (_, v) in grads_and_vars]
# self.layers = var_list
out = self._internal_optimizer.apply_gradients(*args, **kwargs)
self.limit_learning_rate(self.loss.beta(self.class_weights),
self.loss.gamma()
)
self.project_weights_to_r()
return out
def minimize(self, *args, **kwargs):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
return self._internal_optimizer.minimize(*args, **kwargs)
# self.layers = kwargs.get('var_list', None)
out = self._internal_optimizer.minimize(*args, **kwargs)
self.limit_learning_rate(self.loss.beta(self.class_weights),
self.loss.gamma()
)
self.project_weights_to_r()
return out
def _compute_gradients(self, *args, **kwargs):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
# self.layers = kwargs.get('var_list', None)
return self._internal_optimizer._compute_gradients(*args, **kwargs)
def get_gradients(self, *args, **kwargs):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer.
"""
# self.layers = kwargs.get('params', None)
return self._internal_optimizer.get_gradients(*args, **kwargs)
def __enter__(self):
noise_distribution = self.noise_distribution
epsilon = self.epsilon
class_weights = self.class_weights
n_samples = self.n_samples
if noise_distribution not in _accepted_distributions:
raise ValueError('Detected noise distribution: {0} not one of: {1} valid'
'distributions'.format(noise_distribution,
_accepted_distributions))
self.noise_distribution = noise_distribution
self.epsilon = epsilon
self.class_weights = class_weights
self.n_samples = n_samples
return self
def __call__(self,
noise_distribution,
epsilon,
layers,
class_weights,
n_samples,
n_classes,
):
"""
Args:
noise_distribution: the noise distribution to pick.
see _accepted_distributions and get_noise for
possible values.
epsilon: privacy parameter. Lower gives more privacy but less utility.
class_weights: class_weights used
n_samples number of rows/individual samples in the training set
n_classes: number of output classes
layers: list of Keras/Tensorflow layers.
"""
if epsilon <= 0:
raise ValueError('Detected epsilon: {0}. '
'Valid range is 0 < epsilon <inf'.format(epsilon))
self.noise_distribution = noise_distribution
self.epsilon = epsilon
self.class_weights = class_weights
self.n_samples = n_samples
self.n_classes = n_classes
self.layers = layers
return self
def __exit__(self, *args):
"""Exit call from with statement.
used to
1.reset the model and fit parameters passed to the optimizer
to enable the Bolton Privacy guarantees. These are reset to ensure
that any future calls to fit with the same instance of the optimizer
will properly error out.
2.call post-fit methods normalizing/projecting the model weights and
adding noise to the weights.
"""
# for param in self.layers:
# if param.name.find('kernel') != -1 or param.name.find('weight') != -1:
# input_dim = param.numpy().shape[0]
# print(param)
# noise = -1 * self.get_noise(self.n_samples,
# input_dim,
# self.n_classes,
# self.class_weights
# )
# print(tf.math.subtract(param, noise))
# param.assign(tf.math.subtract(param, noise))
self.project_weights_to_r(True)
for layer in self.layers:
input_dim, output_dim = layer.kernel.shape
noise = self.get_noise(self.n_samples,
input_dim,
output_dim,
self.class_weights
)
layer.kernel = tf.math.add(layer.kernel, noise)
self.noise_distribution = None
self.epsilon = -1
self.class_weights = None
self.n_samples = None
self.input_dim = None
self.n_classes = None
self.layers = None

View file

@ -21,19 +21,129 @@ import tensorflow as tf
from tensorflow.python.platform import test
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
from tensorflow.python.keras import keras_parameterized
from privacy.bolton import model
from privacy.bolton import optimizer as opt
from tensorflow.python.keras.regularizers import L1L2
from tensorflow.python.keras import losses
from tensorflow.python.keras.models import Model
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import test_util
from absl.testing import parameterized
from absl.testing import absltest
from privacy.bolton.loss import StrongConvexMixin
from privacy.bolton import optimizer as opt
class TestModel(Model):
"""
Bolton episilon-delta model
Uses 4 key steps to achieve privacy guarantees:
1. Adds noise to weights after training (output perturbation).
2. Projects weights to R after each batch
3. Limits learning rate
4. Use a strongly convex loss function (see compile)
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et. al.
"""
def __init__(self, n_classes=2):
"""
Args:
n_classes: number of output classes to predict.
epsilon: level of privacy guarantee
noise_distribution: distribution to pull weight perturbations from
weights_initializer: initializer for weights
seed: random seed to use
dtype: data type to use for tensors
"""
super(TestModel, self).__init__(name='bolton', dynamic=False)
self.n_classes = n_classes
self.layer_input_shape = (16, 1)
self.output_layer = tf.keras.layers.Dense(
self.n_classes,
input_shape=self.layer_input_shape,
kernel_regularizer=L1L2(l2=1),
kernel_initializer='glorot_uniform',
)
# def call(self, inputs):
# """Forward pass of network
#
# Args:
# inputs: inputs to neural network
#
# Returns:
#
# """
# return self.output_layer(inputs)
class TestLoss(losses.Loss, StrongConvexMixin):
"""Test loss function for testing Bolton model"""
def __init__(self, reg_lambda, C, radius_constant, name='test'):
super(TestLoss, self).__init__(name=name)
self.reg_lambda = reg_lambda
self.C = C
self.radius_constant = radius_constant
def radius(self):
"""Radius of R-Ball (value to normalize weights to after each batch)
Returns: radius
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def gamma(self):
""" Gamma strongly convex
Returns: gamma
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def beta(self, class_weight):
"""Beta smoothess
Args:
class_weight: the class weights used.
Returns: Beta
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def lipchitz_constant(self, class_weight):
""" L lipchitz continuous
Args:
class_weight: class weights used
Returns: L
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def call(self, val0, val1):
"""Loss function that is minimized at the mean of the input points."""
return 0.5 * tf.reduce_sum(tf.math.squared_difference(val0, val1), axis=1)
def max_class_weight(self, class_weight):
if class_weight is None:
return 1
def kernel_regularizer(self):
return L1L2(l2=self.reg_lambda)
class TestOptimizer(OptimizerV2):
"""Optimizer used for testing the Private optimizer"""
"""Optimizer used for testing the Bolton optimizer"""
def __init__(self):
super(TestOptimizer, self).__init__('test')
self.not_private = 'test'
self.iterations = tf.Variable(1, dtype=tf.float32)
self._iterations = tf.Variable(1, dtype=tf.float32)
self.iterations = tf.constant(1, dtype=tf.float32)
self._iterations = tf.constant(1, dtype=tf.float32)
def _compute_gradients(self, loss, var_list, grad_loss=None):
return 'test'
@ -41,7 +151,7 @@ class TestOptimizer(OptimizerV2):
def get_config(self):
return 'test'
def from_config(cls, config, custom_objects=None):
def from_config(self, config, custom_objects=None):
return 'test'
def _create_slots(self):
@ -65,34 +175,22 @@ class TestOptimizer(OptimizerV2):
def get_gradients(self, loss, params):
return 'test'
class PrivateTest(keras_parameterized.TestCase):
"""Private Optimizer tests"""
def limit_learning_rate(self):
return 'test'
class BoltonOptimizerTest(keras_parameterized.TestCase):
"""Bolton Optimizer tests"""
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': 'branch True, beta',
{'testcase_name': 'branch beta',
'fn': 'limit_learning_rate',
'args': [True,
tf.Variable(2, dtype=tf.float32),
'args': [tf.Variable(2, dtype=tf.float32),
tf.Variable(1, dtype=tf.float32)],
'result': tf.Variable(0.5, dtype=tf.float32),
'test_attr': 'learning_rate'},
{'testcase_name': 'branch True, gamma',
{'testcase_name': 'branch gamma',
'fn': 'limit_learning_rate',
'args': [True,
tf.Variable(1, dtype=tf.float32),
tf.Variable(1, dtype=tf.float32)],
'result': tf.Variable(1, dtype=tf.float32),
'test_attr': 'learning_rate'},
{'testcase_name': 'branch False, beta',
'fn': 'limit_learning_rate',
'args': [False,
tf.Variable(2, dtype=tf.float32),
tf.Variable(1, dtype=tf.float32)],
'result': tf.Variable(0.5, dtype=tf.float32),
'test_attr': 'learning_rate'},
{'testcase_name': 'branch False, gamma',
'fn': 'limit_learning_rate',
'args': [False,
tf.Variable(1, dtype=tf.float32),
'args': [tf.Variable(1, dtype=tf.float32),
tf.Variable(1, dtype=tf.float32)],
'result': tf.Variable(1, dtype=tf.float32),
'test_attr': 'learning_rate'},
@ -101,9 +199,26 @@ class PrivateTest(keras_parameterized.TestCase):
'args': ['dtype'],
'result': tf.float32,
'test_attr': None},
{'testcase_name': 'project_weights_to_r',
'fn': 'project_weights_to_r',
'args': ['dtype'],
'result': tf.float32,
'test_attr': None},
])
def test_fn(self, fn, args, result, test_attr):
private = opt.Private(TestOptimizer())
"""test that a fn of Bolton optimizer is working as expected.
Args:
fn: method of Optimizer to test
args: args to optimizer fn
result: the expected result
test_attr: None if the fn returns the test result. Otherwise, this is
the attribute of Bolton to check against result with.
"""
tf.random.set_seed(1)
loss = TestLoss(1, 1, 1)
private = opt.Bolton(TestOptimizer(), loss)
res = getattr(private, fn, None)(*args)
if test_attr is not None:
res = getattr(private, test_attr, None)
@ -142,41 +257,88 @@ class PrivateTest(keras_parameterized.TestCase):
'args': [1, 1]},
])
def test_rerouted_function(self, fn, args):
""" tests that a method of the internal optimizer is correctly routed from
the Bolton instance to the internal optimizer instance (TestOptimizer,
here).
Args:
fn: fn to test
args: arguments to that fn
"""
loss = TestLoss(1, 1, 1)
optimizer = TestOptimizer()
optimizer = opt.Private(optimizer)
self.assertEqual(
getattr(optimizer, fn, lambda: 'fn not found')(*args),
'test'
)
optimizer = opt.Bolton(optimizer, loss)
model = TestModel(2)
model.compile(optimizer, loss)
model.layers[0].kernel_initializer(model.layer_input_shape)
print(model.layers[0].__dict__)
with optimizer('laplace', 2, model.layers, 1, 1, model.n_classes):
self.assertEqual(
getattr(optimizer, fn, lambda: 'fn not found')(*args),
'test'
)
@parameterized.named_parameters([
{'testcase_name': 'fn: limit_learning_rate',
'fn': 'limit_learning_rate',
'args': [1, 1, 1]}
'args': [1, 1, 1]},
{'testcase_name': 'fn: project_weights_to_r',
'fn': 'project_weights_to_r',
'args': []},
{'testcase_name': 'fn: get_noise',
'fn': 'get_noise',
'args': [1, 1, 1, 1]},
])
def test_not_reroute_fn(self, fn, args):
"""Test that a fn that should not be rerouted to the internal optimizer is
in face not rerouted.
Args:
fn: fn to test
args: arguments to that fn
"""
optimizer = TestOptimizer()
optimizer = opt.Private(optimizer)
loss = TestLoss(1, 1, 1)
optimizer = opt.Bolton(optimizer, loss)
self.assertNotEqual(getattr(optimizer, fn, lambda: 'test')(*args),
'test')
@parameterized.named_parameters([
{'testcase_name': 'attr: not_private',
'attr': 'not_private'}
{'testcase_name': 'attr: _iterations',
'attr': '_iterations'}
])
def test_reroute_attr(self, attr):
""" test that attribute of internal optimizer is correctly rerouted to
the internal optimizer
Args:
attr: attribute to test
result: result after checking attribute
"""
loss = TestLoss(1, 1, 1)
internal_optimizer = TestOptimizer()
optimizer = opt.Private(internal_optimizer)
self.assertEqual(optimizer._internal_optimizer, internal_optimizer)
optimizer = opt.Bolton(internal_optimizer, loss)
self.assertEqual(getattr(optimizer, attr),
getattr(internal_optimizer, attr)
)
@parameterized.named_parameters([
{'testcase_name': 'attr: _internal_optimizer',
'attr': '_internal_optimizer'}
{'testcase_name': 'attr does not exist',
'attr': '_not_valid'}
])
def test_not_reroute_attr(self, attr):
def test_attribute_error(self, attr):
""" test that attribute of internal optimizer is correctly rerouted to
the internal optimizer
Args:
attr: attribute to test
result: result after checking attribute
"""
loss = TestLoss(1, 1, 1)
internal_optimizer = TestOptimizer()
optimizer = opt.Private(internal_optimizer)
self.assertEqual(optimizer._internal_optimizer, internal_optimizer)
optimizer = opt.Bolton(internal_optimizer, loss)
with self.assertRaises(AttributeError):
getattr(optimizer, attr)
if __name__ == '__main__':
test.main()
test.main()