Working Bolton model without unit tests.

-- Moving to the Bolton optimizer:
The model is now just a convenient wrapper and example for users.
The optimizer holds ALL of the Bolton privacy requirements.
The optimizer is used as a context manager and must be passed the model's layers (see the usage sketch below).
Unit tests are incomplete; committing for visibility into the design.
Christopher Choquette Choo 2019-06-13 01:01:31 -04:00
parent 751eaead54
commit ec18db5ec5
5 changed files with 913 additions and 345 deletions
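
For orientation, here is a minimal usage sketch of the refactored API, pieced together from the new fit() signature and the __main__ smoke test in the model changes below. The data, hyperparameter values, and exact call sequence are illustrative assumptions; the commit itself notes the design is still unverified by tests.

import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import adam
from privacy.bolton import model as bolton_model
from privacy.bolton.loss import StrongConvexBinaryCrossentropy

# Toy data; fit() assumes inputs are normalized so that ||x|| < 1.
x_train = tf.random.uniform((20, 4), maxval=0.4)
y_train = tf.one_hot(tf.random.uniform((20,), maxval=2, dtype=tf.int32), 2)

# The model is only a thin wrapper; compile() wraps the passed optimizer in a
# Bolton optimizer, which now owns every privacy requirement.
clf = bolton_model.BoltonModel(n_classes=2)
clf.compile(adam.Adam(), StrongConvexBinaryCrossentropy(0.001, 1, 1))

# epsilon and the noise distribution moved from __init__ to fit(); internally,
# fit() enters the Bolton optimizer as a context manager, handing it the
# model's layers, class weights, and sample count.
clf.fit(x_train, y_train, batch_size=4, epochs=1,
        n_samples=20, epsilon=2, noise_distribution='laplace')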


@ -18,23 +18,17 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from tensorflow.python.platform import test
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.optimizer_v2 import adam
from tensorflow.python.keras.optimizer_v2 import adagrad
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras import losses
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from privacy.bolton import model from tensorflow.python.keras.regularizers import L1L2
from absl.testing import parameterized
from privacy.bolton.loss import StrongConvexBinaryCrossentropy from privacy.bolton.loss import StrongConvexBinaryCrossentropy
from privacy.bolton.loss import StrongConvexHuber from privacy.bolton.loss import StrongConvexHuber
from privacy.bolton.loss import StrongConvexMixin from privacy.bolton.loss import StrongConvexMixin
from absl.testing import parameterized
from absl.testing import absltest
from tensorflow.python.keras.regularizers import L1L2
class StrongConvexTests(keras_parameterized.TestCase): class StrongConvexMixinTests(keras_parameterized.TestCase):
"""Tests for the StrongConvexMixin"""
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'beta not implemented', {'testcase_name': 'beta not implemented',
'fn': 'beta', 'fn': 'beta',
@ -50,6 +44,12 @@ class StrongConvexTests(keras_parameterized.TestCase):
'args': []}, 'args': []},
]) ])
def test_not_implemented(self, fn, args): def test_not_implemented(self, fn, args):
"""Test that the given fn's are not implemented on the mixin.
Args:
fn: fn on Mixin to test
args: arguments to fn of Mixin
"""
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
loss = StrongConvexMixin() loss = StrongConvexMixin()
getattr(loss, fn, None)(*args) getattr(loss, fn, None)(*args)
@ -60,6 +60,12 @@ class StrongConvexTests(keras_parameterized.TestCase):
'args': []}, 'args': []},
]) ])
def test_return_none(self, fn, args): def test_return_none(self, fn, args):
"""Test that fn of Mixin returns None
Args:
fn: fn of Mixin to test
args: arguments to fn of Mixin
"""
loss = StrongConvexMixin() loss = StrongConvexMixin()
ret = getattr(loss, fn, None)(*args) ret = getattr(loss, fn, None)(*args)
self.assertEqual(ret, None) self.assertEqual(ret, None)
@ -71,44 +77,56 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'normal', {'testcase_name': 'normal',
'reg_lambda': 1, 'reg_lambda': 1,
'c': 1, 'C': 1,
'radius_constant': 1 'radius_constant': 1
}, },
]) ])
def test_init_params(self, reg_lambda, c, radius_constant): def test_init_params(self, reg_lambda, C, radius_constant):
"""Test initialization for given arguments
Args:
reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg
radius_constant: initialization value for radius_constant arg
"""
# test valid domains for each variable # test valid domains for each variable
loss = StrongConvexBinaryCrossentropy(reg_lambda, c, radius_constant) loss = StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
self.assertIsInstance(loss, StrongConvexBinaryCrossentropy) self.assertIsInstance(loss, StrongConvexBinaryCrossentropy)
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'negative c', {'testcase_name': 'negative c',
'reg_lambda': 1, 'reg_lambda': 1,
'c': -1, 'C': -1,
'radius_constant': 1 'radius_constant': 1
}, },
{'testcase_name': 'negative radius', {'testcase_name': 'negative radius',
'reg_lambda': 1, 'reg_lambda': 1,
'c': 1, 'C': 1,
'radius_constant': -1 'radius_constant': -1
}, },
{'testcase_name': 'negative lambda', {'testcase_name': 'negative lambda',
'reg_lambda': -1, 'reg_lambda': -1,
'c': 1, 'C': 1,
'radius_constant': 1 'radius_constant': 1
}, },
]) ])
def test_bad_init_params(self, reg_lambda, c, radius_constant): def test_bad_init_params(self, reg_lambda, C, radius_constant):
"""Test invalid domain for given params. Should return ValueError
Args:
reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg
radius_constant: initialization value for radius_constant arg
"""
# test valid domains for each variable # test valid domains for each variable
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
loss = StrongConvexBinaryCrossentropy(reg_lambda, c, radius_constant) StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
@test_util.run_all_in_graph_and_eager_modes @test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([ @parameterized.named_parameters([
# [] for compatibility with tensorflow loss calculation # [] for compatibility with tensorflow loss calculation
{'testcase_name': 'both positive', {'testcase_name': 'both positive',
'logits': [10000], 'logits': [10000],
'y_true': [1], 'y_true': [1],
'result': 0, 'result': 0,
}, },
{'testcase_name': 'positive gradient negative logits', {'testcase_name': 'positive gradient negative logits',
'logits': [-10000], 'logits': [-10000],
@ -127,6 +145,12 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
}, },
]) ])
def test_calculation(self, logits, y_true, result): def test_calculation(self, logits, y_true, result):
"""Test the call method to ensure it returns the correct value
Args:
logits: unscaled output of model
y_true: label
result: correct loss calculation value
"""
logits = tf.Variable(logits, False, dtype=tf.float32) logits = tf.Variable(logits, False, dtype=tf.float32)
y_true = tf.Variable(y_true, False, dtype=tf.float32) y_true = tf.Variable(y_true, False, dtype=tf.float32)
loss = StrongConvexBinaryCrossentropy(0.00001, 1, 1) loss = StrongConvexBinaryCrossentropy(0.00001, 1, 1)
@ -160,6 +184,13 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
}, },
]) ])
def test_fns(self, init_args, fn, args, result): def test_fns(self, init_args, fn, args, result):
"""Test that fn of BinaryCrossentropy loss returns the correct result
Args:
init_args: init values for loss instance
fn: the fn to test
args: the arguments to above function
result: the correct result from the fn
"""
loss = StrongConvexBinaryCrossentropy(*init_args) loss = StrongConvexBinaryCrossentropy(*init_args)
expected = getattr(loss, fn, lambda: 'fn not found')(*args) expected = getattr(loss, fn, lambda: 'fn not found')(*args)
if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor
@ -183,6 +214,12 @@ class HuberTests(keras_parameterized.TestCase):
}, },
]) ])
def test_init_params(self, reg_lambda, c, radius_constant, delta): def test_init_params(self, reg_lambda, c, radius_constant, delta):
"""Test initialization for given arguments
Args:
reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg
radius_constant: initialization value for radius_constant arg
"""
# test valid domains for each variable # test valid domains for each variable
loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta) loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta)
self.assertIsInstance(loss, StrongConvexHuber) self.assertIsInstance(loss, StrongConvexHuber)
@ -214,18 +251,24 @@ class HuberTests(keras_parameterized.TestCase):
}, },
]) ])
def test_bad_init_params(self, reg_lambda, c, radius_constant, delta): def test_bad_init_params(self, reg_lambda, c, radius_constant, delta):
"""Test invalid domain for given params. Should return ValueError
Args:
reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg
radius_constant: initialization value for radius_constant arg
"""
# test valid domains for each variable # test valid domains for each variable
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta) StrongConvexHuber(reg_lambda, c, radius_constant, delta)
# test the bounds and test varied delta's # test the bounds and test varied delta's
@test_util.run_all_in_graph_and_eager_modes @test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'delta=1,y_true=1 z>1+h decision boundary', {'testcase_name': 'delta=1,y_true=1 z>1+h decision boundary',
'logits': 2.1, 'logits': 2.1,
'y_true': 1, 'y_true': 1,
'delta': 1, 'delta': 1,
'result': 0, 'result': 0,
}, },
{'testcase_name': 'delta=1,y_true=1 z<1+h decision boundary', {'testcase_name': 'delta=1,y_true=1 z<1+h decision boundary',
'logits': 1.9, 'logits': 1.9,
@ -277,6 +320,12 @@ class HuberTests(keras_parameterized.TestCase):
}, },
]) ])
def test_calculation(self, logits, y_true, delta, result): def test_calculation(self, logits, y_true, delta, result):
"""Test the call method to ensure it returns the correct value
Args:
logits: unscaled output of model
y_true: label
result: correct loss calculation value
"""
logits = tf.Variable(logits, False, dtype=tf.float32) logits = tf.Variable(logits, False, dtype=tf.float32)
y_true = tf.Variable(y_true, False, dtype=tf.float32) y_true = tf.Variable(y_true, False, dtype=tf.float32)
loss = StrongConvexHuber(0.00001, 1, 1, delta) loss = StrongConvexHuber(0.00001, 1, 1, delta)
@ -310,6 +359,13 @@ class HuberTests(keras_parameterized.TestCase):
}, },
]) ])
def test_fns(self, init_args, fn, args, result): def test_fns(self, init_args, fn, args, result):
"""Test that fn of BinaryCrossentropy loss returns the correct result
Args:
init_args: init values for loss instance
fn: the fn to test
args: the arguments to above function
result: the correct result from the fn
"""
loss = StrongConvexHuber(*init_args) loss = StrongConvexHuber(*init_args)
expected = getattr(loss, fn, lambda: 'fn not found')(*args) expected = getattr(loss, fn, lambda: 'fn not found')(*args)
if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor
@ -322,4 +378,4 @@ class HuberTests(keras_parameterized.TestCase):
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
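
A compact sketch of what the loss tests above exercise: valid hyperparameters construct the strongly convex losses, saturated logits with a matching label give a loss near zero, and negative reg_lambda, C, or radius_constant raise ValueError at construction. The values mirror the test cases; the call convention (y_true first) follows the standard Keras Loss interface and is an assumption here.

import tensorflow as tf
from privacy.bolton.loss import StrongConvexBinaryCrossentropy
from privacy.bolton.loss import StrongConvexHuber

# Valid domains; positional args are (reg_lambda, C, radius_constant[, delta]).
bce = StrongConvexBinaryCrossentropy(0.00001, 1, 1)
huber = StrongConvexHuber(0.00001, 1, 1, 1)

# 'both positive' test case: large positive logits with label 1 -> ~0 loss.
logits = tf.constant([10000.0])
y_true = tf.constant([1.0])
print(float(bce(y_true, logits)))  # expected to be approximately 0

# Invalid domains are rejected when the loss is built.
try:
  StrongConvexBinaryCrossentropy(1, -1, 1)  # negative C
except ValueError as err:
  print('rejected:', err)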


@ -21,12 +21,12 @@ from tensorflow.python.keras.models import Model
from tensorflow.python.keras import optimizers from tensorflow.python.keras import optimizers
from tensorflow.python.framework import ops as _ops from tensorflow.python.framework import ops as _ops
from privacy.bolton.loss import StrongConvexMixin from privacy.bolton.loss import StrongConvexMixin
from privacy.bolton.optimizer import Private from privacy.bolton.optimizer import Bolton
_accepted_distributions = ['laplace'] _accepted_distributions = ['laplace']
class Bolton(Model): class BoltonModel(Model):
""" """
Bolton episilon-delta model Bolton episilon-delta model
Uses 4 key steps to achieve privacy guarantees: Uses 4 key steps to achieve privacy guarantees:
@ -42,8 +42,7 @@ class Bolton(Model):
def __init__(self, def __init__(self,
n_classes, n_classes,
epsilon, # noise_distribution='laplace',
noise_distribution='laplace',
seed=1, seed=1,
dtype=tf.float32 dtype=tf.float32
): ):
@ -58,47 +57,22 @@ class Bolton(Model):
dtype: data type to use for tensors dtype: data type to use for tensors
""" """
class MyCustomCallback(tf.keras.callbacks.Callback): # if noise_distribution not in _accepted_distributions:
"""Custom callback for bolton training requirements. # raise ValueError('Detected noise distribution: {0} not one of: {1} valid'
Implements steps (see Bolton class): # 'distributions'.format(noise_distribution,
2. Projects weights to R after each batch # _accepted_distributions))
3. Limits learning rate # if epsilon <= 0:
""" # raise ValueError('Detected epsilon: {0}. '
# 'Valid range is 0 < epsilon <inf'.format(epsilon))
def on_train_batch_end(self, batch, logs=None): # self.epsilon = epsilon
loss = self.model.loss super(BoltonModel, self).__init__(name='bolton', dynamic=False)
self.model.optimizer.limit_learning_rate(
self.model.run_eagerly,
loss.beta(self.model.class_weight),
loss.gamma()
)
self.model._project_weights_to_r(loss.radius(), False)
def on_train_end(self, logs=None):
loss = self.model.loss
self.model._project_weights_to_r(loss.radius(), True)
if epsilon <= 0:
raise ValueError('Detected epsilon: {0}. '
'Valid range is 0 < epsilon <inf'.format(epsilon))
if noise_distribution not in _accepted_distributions:
raise ValueError('Detected noise distribution: {0} not one of: {1} valid'
'distributions'.format(noise_distribution,
_accepted_distributions))
super(Bolton, self).__init__(name='bolton', dynamic=False)
self.n_classes = n_classes self.n_classes = n_classes
# if we do regularization here, we require the user to re-instantiate
# the model each time they want to
# change lambda, unless we standardize modifying it later at .compile
self.force = False self.force = False
self.noise_distribution = noise_distribution # self.noise_distribution = noise_distribution
self.epsilon = epsilon
self.seed = seed self.seed = seed
self.__in_fit = False self.__in_fit = False
self._layers_instantiated = False self._layers_instantiated = False
self._callback = MyCustomCallback() # self._callback = MyCustomCallback()
self._dtype = dtype self._dtype = dtype
def call(self, inputs): def call(self, inputs):
@ -139,55 +113,65 @@ class Bolton(Model):
kernel_regularizer=loss.kernel_regularizer(), kernel_regularizer=loss.kernel_regularizer(),
kernel_initializer=kernel_intiializer(), kernel_initializer=kernel_intiializer(),
) )
# if we don't do regularization here, we require the user to
# re-instantiate the model each time they want to change the penalty
# weighting
self._layers_instantiated = True self._layers_instantiated = True
self.output_layer.kernel_regularizer.l2 = loss.reg_lambda self.output_layer.kernel_regularizer.l2 = loss.reg_lambda
if not isinstance(optimizer, Private): if not isinstance(optimizer, Bolton):
optimizer = optimizers.get(optimizer) optimizer = optimizers.get(optimizer)
optimizer = Private(optimizer) optimizer = Bolton(optimizer, loss)
super(Bolton, self).compile(optimizer, super(BoltonModel, self).compile(optimizer,
loss=loss, loss=loss,
metrics=metrics, metrics=metrics,
loss_weights=loss_weights, loss_weights=loss_weights,
sample_weight_mode=sample_weight_mode, sample_weight_mode=sample_weight_mode,
weighted_metrics=weighted_metrics, weighted_metrics=weighted_metrics,
target_tensors=target_tensors, target_tensors=target_tensors,
distribute=distribute, distribute=distribute,
**kwargs **kwargs
) )
def _post_fit(self, x, n_samples): # def _post_fit(self, x, n_samples):
"""Implements 1-time weight changes needed for Bolton method. # """Implements 1-time weight changes needed for Bolton method.
In this case, specifically implements the noise addition # In this case, specifically implements the noise addition
assuming a strongly convex function. # assuming a strongly convex function.
#
Args: # Args:
x: inputs # x: inputs
n_samples: number of samples in the inputs. In case the number # n_samples: number of samples in the inputs. In case the number
cannot be readily determined by inspecting x. # cannot be readily determined by inspecting x.
#
Returns: # Returns:
#
""" # """
data_size = None # data_size = None
if n_samples is not None: # if n_samples is not None:
data_size = n_samples # data_size = n_samples
elif hasattr(x, 'shape'): # elif hasattr(x, 'shape'):
data_size = x.shape[0] # data_size = x.shape[0]
elif hasattr(x, "__len__"): # elif hasattr(x, "__len__"):
data_size = len(x) # data_size = len(x)
elif data_size is None: # elif data_size is None:
if n_samples is None: # if n_samples is None:
raise ValueError("Unable to detect the number of training " # raise ValueError("Unable to detect the number of training "
"samples and n_smaples was None. " # "samples and n_smaples was None. "
"either pass a dataset with a .shape or " # "either pass a dataset with a .shape or "
"__len__ attribute or explicitly pass the " # "__len__ attribute or explicitly pass the "
"number of samples as n_smaples.") # "number of samples as n_smaples.")
for layer in self._layers: # for layer in self.layers:
layer.kernel = layer.kernel + self._get_noise( # # layer.kernel = layer.kernel + self._get_noise(
self.noise_distribution, # # data_size
data_size # # )
) # input_dim = layer.kernel.numpy().shape[0]
# layer.kernel = layer.kernel + self.optimizer.get_noise(
# self.loss,
# data_size,
# input_dim,
# self.n_classes,
# self.class_weight
# )
def fit(self, def fit(self,
x=None, x=None,
@ -209,6 +193,8 @@ class Bolton(Model):
workers=1, workers=1,
use_multiprocessing=False, use_multiprocessing=False,
n_samples=None, n_samples=None,
epsilon=2,
noise_distribution='laplace',
**kwargs): **kwargs):
"""Reroutes to super fit with additional Bolton delta-epsilon privacy """Reroutes to super fit with additional Bolton delta-epsilon privacy
requirements implemented. Note, inputs must be normalized s.t. ||x|| < 1 requirements implemented. Note, inputs must be normalized s.t. ||x|| < 1
@ -226,35 +212,40 @@ class Bolton(Model):
""" """
self.__in_fit = True self.__in_fit = True
cb = [self._callback] # cb = [self.optimizer.callbacks]
if callbacks is not None: # if callbacks is not None:
cb.extend(callbacks) # cb.extend(callbacks)
callbacks = cb # callbacks = cb
if class_weight is None: if class_weight is None:
class_weight = self.calculate_class_weights(class_weight) class_weight = self.calculate_class_weights(class_weight)
self.class_weight = class_weight # self.class_weight = class_weight
out = super(Bolton, self).fit(x=x, with self.optimizer(noise_distribution,
y=y, epsilon,
batch_size=batch_size, self.layers,
epochs=epochs, class_weight,
verbose=verbose, n_samples,
callbacks=callbacks, self.n_classes,
validation_split=validation_split, ) as optim:
validation_data=validation_data, out = super(BoltonModel, self).fit(x=x,
shuffle=shuffle, y=y,
class_weight=class_weight, batch_size=batch_size,
sample_weight=sample_weight, epochs=epochs,
initial_epoch=initial_epoch, verbose=verbose,
steps_per_epoch=steps_per_epoch, callbacks=callbacks,
validation_steps=validation_steps, validation_split=validation_split,
validation_freq=validation_freq, validation_data=validation_data,
max_queue_size=max_queue_size, shuffle=shuffle,
workers=workers, class_weight=class_weight,
use_multiprocessing=use_multiprocessing, sample_weight=sample_weight,
**kwargs initial_epoch=initial_epoch,
) steps_per_epoch=steps_per_epoch,
self._post_fit(x, n_samples) validation_steps=validation_steps,
self.__in_fit = False validation_freq=validation_freq,
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
**kwargs
)
return out return out
def fit_generator(self, def fit_generator(self,
@ -284,7 +275,7 @@ class Bolton(Model):
if class_weight is None: if class_weight is None:
class_weight = self.calculate_class_weights(class_weight) class_weight = self.calculate_class_weights(class_weight)
self.class_weight = class_weight self.class_weight = class_weight
out = super(Bolton, self).fit_generator( out = super(BoltonModel, self).fit_generator(
generator, generator,
steps_per_epoch=steps_per_epoch, steps_per_epoch=steps_per_epoch,
epochs=epochs, epochs=epochs,
@ -366,66 +357,195 @@ class Bolton(Model):
"1D array".format(class_weights.shape)) "1D array".format(class_weights.shape))
if class_weights.shape[0] != num_classes: if class_weights.shape[0] != num_classes:
raise ValueError( raise ValueError(
"Detected array length: {0} instead of: {1}".format( "Detected array length: {0} instead of: {1}".format(
class_weights.shape[0], class_weights.shape[0],
num_classes num_classes
) )
) )
return class_weights return class_weights
def _project_weights_to_r(self, r, force=False): # def _project_weights_to_r(self, r, force=False):
"""helper method to normalize the weights to the R-ball. # """helper method to normalize the weights to the R-ball.
#
# Args:
# r: radius of "R-Ball". Scalar to normalize to.
# force: True to normalize regardless of previous weight values.
# False to check if weights > R-ball and only normalize then.
#
# Returns:
#
# """
# for layer in self.layers:
# weight_norm = tf.norm(layer.kernel, axis=0)
# if force:
# layer.kernel = layer.kernel / (weight_norm / r)
# elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self._dtype)) > 0:
# layer.kernel = layer.kernel / (weight_norm / r)
Args: # def _get_noise(self, distribution, data_size):
r: radius of "R-Ball". Scalar to normalize to. # """Sample noise to be added to weights for privacy guarantee
force: True to normalize regardless of previous weight values. #
False to check if weights > R-ball and only normalize then. # Args:
# distribution: the distribution type to pull noise from
# data_size: the number of samples
#
# Returns: noise in shape of layer's weights to be added to the weights.
#
# """
# distribution = distribution.lower()
# input_dim = self.layers[0].kernel.numpy().shape[0]
# loss = self.loss
# if distribution == _accepted_distributions[0]: # laplace
# per_class_epsilon = self.epsilon / (self.n_classes)
# l2_sensitivity = (2 *
# loss.lipchitz_constant(self.class_weight)) / \
# (loss.gamma() * data_size)
# unit_vector = tf.random.normal(shape=(input_dim, self.n_classes),
# mean=0,
# seed=1,
# stddev=1.0,
# dtype=self._dtype)
# unit_vector = unit_vector / tf.math.sqrt(
# tf.reduce_sum(tf.math.square(unit_vector), axis=0)
# )
#
# beta = l2_sensitivity / per_class_epsilon
# alpha = input_dim # input_dim
# gamma = tf.random.gamma([self.n_classes],
# alpha,
# beta=1 / beta,
# seed=1,
# dtype=self._dtype
# )
# return unit_vector * gamma
# raise NotImplementedError('Noise distribution: {0} is not '
# 'a valid distribution'.format(distribution))
Returns:
""" if __name__ == '__main__':
for layer in self._layers: import tensorflow as tf
weight_norm = tf.norm(layer.kernel, axis=0)
if force:
layer.kernel = layer.kernel / (weight_norm / r)
elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self._dtype)) > 0:
layer.kernel = layer.kernel / (weight_norm / r)
def _get_noise(self, distribution, data_size): import os
"""Sample noise to be added to weights for privacy guarantee import time
import matplotlib.pyplot as plt
Args: _URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz'
distribution: the distribution type to pull noise from
data_size: the number of samples
Returns: noise in shape of layer's weights to be added to the weights. path_to_zip = tf.keras.utils.get_file('facades.tar.gz',
origin=_URL,
extract=True)
""" PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')
distribution = distribution.lower() BUFFER_SIZE = 400
input_dim = self._layers[0].kernel.numpy().shape[0] BATCH_SIZE = 1
loss = self.loss IMG_WIDTH = 256
if distribution == _accepted_distributions[0]: # laplace IMG_HEIGHT = 256
per_class_epsilon = self.epsilon / (self.n_classes)
l2_sensitivity = (2 *
loss.lipchitz_constant(self.class_weight)) / \
(loss.gamma() * data_size)
unit_vector = tf.random.normal(shape=(input_dim, self.n_classes),
mean=0,
seed=1,
stddev=1.0,
dtype=self._dtype)
unit_vector = unit_vector / tf.math.sqrt(
tf.reduce_sum(tf.math.square(unit_vector), axis=0)
)
beta = l2_sensitivity / per_class_epsilon
alpha = input_dim # input_dim def load(image_file):
gamma = tf.random.gamma([self.n_classes], image = tf.io.read_file(image_file)
alpha, image = tf.image.decode_jpeg(image)
beta=1 / beta,
seed=1, w = tf.shape(image)[1]
dtype=self._dtype
) w = w // 2
return unit_vector * gamma real_image = image[:, :w, :]
raise NotImplementedError('Noise distribution: {0} is not ' input_image = image[:, w:, :]
'a valid distribution'.format(distribution))
input_image = tf.cast(input_image, tf.float32)
real_image = tf.cast(real_image, tf.float32)
return input_image, real_image
inp, re = load(PATH + 'train/100.jpg')
# casting to int for matplotlib to show the image
plt.figure()
plt.imshow(inp / 255.0)
plt.figure()
plt.imshow(re / 255.0)
def resize(input_image, real_image, height, width):
input_image = tf.image.resize(input_image, [height, width],
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
real_image = tf.image.resize(real_image, [height, width],
method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
return input_image, real_image
def random_crop(input_image, real_image):
stacked_image = tf.stack([input_image, real_image], axis=0)
cropped_image = tf.image.random_crop(
stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
return cropped_image[0], cropped_image[1]
def normalize(input_image, real_image):
input_image = (input_image / 127.5) - 1
real_image = (real_image / 127.5) - 1
return input_image, real_image
@tf.function()
def random_jitter(input_image, real_image):
# resizing to 286 x 286 x 3
input_image, real_image = resize(input_image, real_image, 286, 286)
# randomly cropping to 256 x 256 x 3
input_image, real_image = random_crop(input_image, real_image)
if tf.random.uniform(()) > 0.5:
# random mirroring
input_image = tf.image.flip_left_right(input_image)
real_image = tf.image.flip_left_right(real_image)
return input_image, real_image
def load_image_train(image_file):
input_image, real_image = load(image_file)
input_image, real_image = random_jitter(input_image, real_image)
input_image, real_image = normalize(input_image, real_image)
return input_image, real_image
def load_image_test(image_file):
input_image, real_image = load(image_file)
input_image, real_image = resize(input_image, real_image,
IMG_HEIGHT, IMG_WIDTH)
input_image, real_image = normalize(input_image, real_image)
return input_image, real_image
train_dataset = tf.data.Dataset.list_files(PATH + 'train/*.jpg')
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.map(load_image_train,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.batch(1)
# steps_per_epoch = training_utils.infer_steps_for_dataset(
# train_dataset, None, epochs=1, steps_name='steps')
# for batch in train_dataset:
# print(batch[1].shape)
test_dataset = tf.data.Dataset.list_files(PATH + 'test/*.jpg')
# shuffling so that for every epoch a different image is generated
# to predict and display the progress of our model.
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(1)
be = BoltonModel(3, 2)
from tensorflow.python.keras.optimizer_v2 import adam
from privacy.bolton import loss
test = adam.Adam()
l = loss.StrongConvexBinaryCrossentropy(1, 2, 1)
be.compile(test, l)
print("Eager exeuction: {0}".format(tf.executing_eagerly()))
be.fit(train_dataset, verbose=0, steps_per_epoch=1, n_samples=1)
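
The weight projection that used to live in BoltonModel._project_weights_to_r (commented out above) now belongs to the optimizer as project_weights_to_r. The operation itself is a per-column norm clip onto the "R-ball"; below is a self-contained sketch in plain TensorFlow, where the helper name and demo values are illustrative rather than part of the library.

import tensorflow as tf

def project_to_r_ball(kernel, r, force=False):
  """Rescale kernel columns so each has L2 norm at most r (the 'R-ball').

  force=True always rescales to radius r; otherwise columns are only
  rescaled when at least one of them exceeds r, mirroring the diff above.
  """
  weight_norm = tf.norm(kernel, axis=0)  # one norm per output column
  if force:
    return kernel / (weight_norm / r)
  needs_projection = tf.reduce_sum(tf.cast(weight_norm > r, kernel.dtype)) > 0
  return tf.cond(needs_projection,
                 lambda: kernel / (weight_norm / r),
                 lambda: kernel)

kernel = tf.random.normal((16, 2))
projected = project_to_r_ball(kernel, r=1.0, force=True)
print(tf.norm(projected, axis=0))  # every column norm is now ~1.0

In the optimizer's version, the eager branch does the comparison with a Python if while the graph branch uses tf.cond; the arithmetic is the same in both.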


@ -32,7 +32,7 @@ from absl.testing import absltest
from tensorflow.python.keras.regularizers import L1L2 from tensorflow.python.keras.regularizers import L1L2
class TestLoss(losses.Loss): class TestLoss(losses.Loss, StrongConvexMixin):
"""Test loss function for testing Bolton model""" """Test loss function for testing Bolton model"""
def __init__(self, reg_lambda, C, radius_constant, name='test'): def __init__(self, reg_lambda, C, radius_constant, name='test'):
super(TestLoss, self).__init__(name=name) super(TestLoss, self).__init__(name=name)
@ -145,21 +145,25 @@ class InitTests(keras_parameterized.TestCase):
self.assertIsInstance(clf, model.Bolton) self.assertIsInstance(clf, model.Bolton)
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'invalid noise', {'testcase_name': 'invalid noise',
'n_classes': 1, 'n_classes': 1,
'epsilon': 1, 'epsilon': 1,
'noise_distribution': 'not_valid', 'noise_distribution': 'not_valid',
'weights_initializer': tf.initializers.GlorotUniform(), 'weights_initializer': tf.initializers.GlorotUniform(),
}, },
{'testcase_name': 'invalid epsilon', {'testcase_name': 'invalid epsilon',
'n_classes': 1, 'n_classes': 1,
'epsilon': -1, 'epsilon': -1,
'noise_distribution': 'laplace', 'noise_distribution': 'laplace',
'weights_initializer': tf.initializers.GlorotUniform(), 'weights_initializer': tf.initializers.GlorotUniform(),
}, },
]) ])
def test_bad_init_params( def test_bad_init_params(
self, n_classes, epsilon, noise_distribution, weights_initializer): self,
n_classes,
epsilon,
noise_distribution,
weights_initializer):
# test invalid domains for each variable, especially noise # test invalid domains for each variable, especially noise
seed = 1 seed = 1
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -204,16 +208,16 @@ class InitTests(keras_parameterized.TestCase):
self.assertEqual(clf.loss, loss) self.assertEqual(clf.loss, loss)
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'Not strong loss', {'testcase_name': 'Not strong loss',
'n_classes': 1, 'n_classes': 1,
'loss': losses.BinaryCrossentropy(), 'loss': losses.BinaryCrossentropy(),
'optimizer': 'adam', 'optimizer': 'adam',
}, },
{'testcase_name': 'Not valid optimizer', {'testcase_name': 'Not valid optimizer',
'n_classes': 1, 'n_classes': 1,
'loss': TestLoss(1, 1, 1), 'loss': TestLoss(1, 1, 1),
'optimizer': 'ada', 'optimizer': 'ada',
} }
]) ])
def test_bad_compile(self, n_classes, loss, optimizer): def test_bad_compile(self, n_classes, loss, optimizer):
# test compilaton of invalid tf.optimizer and non instantiated loss. # test compilaton of invalid tf.optimizer and non instantiated loss.
@ -250,7 +254,7 @@ def _cat_dataset(n_samples, input_dim, n_classes, t='train', generator=False):
y_stack = [] y_stack = []
for i_class in range(n_classes): for i_class in range(n_classes):
x_stack.append( x_stack.append(
tf.constant(1*i_class, tf.float32, (n_samples, input_dim)) tf.constant(1*i_class, tf.float32, (n_samples, input_dim))
) )
y_stack.append( y_stack.append(
tf.constant(i_class, tf.float32, (n_samples, n_classes)) tf.constant(i_class, tf.float32, (n_samples, n_classes))
@ -258,7 +262,7 @@ def _cat_dataset(n_samples, input_dim, n_classes, t='train', generator=False):
x_set, y_set = tf.stack(x_stack), tf.stack(y_stack) x_set, y_set = tf.stack(x_stack), tf.stack(y_stack)
if generator: if generator:
dataset = tf.data.Dataset.from_tensor_slices( dataset = tf.data.Dataset.from_tensor_slices(
(x_set, y_set) (x_set, y_set)
) )
return dataset return dataset
return x_set, y_set return x_set, y_set
@ -281,10 +285,10 @@ def _do_fit(n_samples,
clf.compile(optimizer, loss) clf.compile(optimizer, loss)
if generator: if generator:
x = _cat_dataset( x = _cat_dataset(
n_samples, n_samples,
input_dim, input_dim,
n_classes, n_classes,
generator=generator generator=generator
) )
y = None y = None
# x = x.batch(batch_size) # x = x.batch(batch_size)
@ -315,26 +319,26 @@ class FitTests(keras_parameterized.TestCase):
# @test_util.run_all_in_graph_and_eager_modes # @test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'iterator fit', {'testcase_name': 'iterator fit',
'generator': False, 'generator': False,
'reset_n_samples': True, 'reset_n_samples': True,
'callbacks': None 'callbacks': None
}, },
{'testcase_name': 'iterator fit no samples', {'testcase_name': 'iterator fit no samples',
'generator': False, 'generator': False,
'reset_n_samples': True, 'reset_n_samples': True,
'callbacks': None 'callbacks': None
}, },
{'testcase_name': 'generator fit', {'testcase_name': 'generator fit',
'generator': True, 'generator': True,
'reset_n_samples': False, 'reset_n_samples': False,
'callbacks': None 'callbacks': None
}, },
{'testcase_name': 'with callbacks', {'testcase_name': 'with callbacks',
'generator': True, 'generator': True,
'reset_n_samples': False, 'reset_n_samples': False,
'callbacks': TestCallback() 'callbacks': TestCallback()
}, },
]) ])
def test_fit(self, generator, reset_n_samples, callbacks): def test_fit(self, generator, reset_n_samples, callbacks):
loss = TestLoss(1, 1, 1) loss = TestLoss(1, 1, 1)
@ -344,9 +348,19 @@ class FitTests(keras_parameterized.TestCase):
epsilon = 1 epsilon = 1
batch_size = 1 batch_size = 1
n_samples = 10 n_samples = 10
clf = _do_fit(n_samples, input_dim, n_classes, epsilon, generator, batch_size, clf = _do_fit(
reset_n_samples, optimizer, loss, callbacks) n_samples,
self.assertEqual(hasattr(clf, '_layers'), True) input_dim,
n_classes,
epsilon,
generator,
batch_size,
reset_n_samples,
optimizer,
loss,
callbacks
)
self.assertEqual(hasattr(clf, 'layers'), True)
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'generator fit', {'testcase_name': 'generator fit',
@ -368,15 +382,15 @@ class FitTests(keras_parameterized.TestCase):
) )
clf.compile(optimizer, loss) clf.compile(optimizer, loss)
x = _cat_dataset( x = _cat_dataset(
n_samples, n_samples,
input_dim, input_dim,
n_classes, n_classes,
generator=generator generator=generator
) )
x = x.batch(batch_size) x = x.batch(batch_size)
x = x.shuffle(n_samples // 2) x = x.shuffle(n_samples // 2)
clf.fit_generator(x, n_samples=n_samples) clf.fit_generator(x, n_samples=n_samples)
self.assertEqual(hasattr(clf, '_layers'), True) self.assertEqual(hasattr(clf, 'layers'), True)
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'iterator no n_samples', {'testcase_name': 'iterator no n_samples',
@ -399,32 +413,43 @@ class FitTests(keras_parameterized.TestCase):
epsilon = 1 epsilon = 1
batch_size = 1 batch_size = 1
n_samples = 10 n_samples = 10
_do_fit(n_samples, input_dim, n_classes, epsilon, generator, batch_size, _do_fit(
reset_n_samples, optimizer, loss, None, distribution) n_samples,
input_dim,
n_classes,
epsilon,
generator,
batch_size,
reset_n_samples,
optimizer,
loss,
None,
distribution
)
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'None class_weights', {'testcase_name': 'None class_weights',
'class_weights': None, 'class_weights': None,
'class_counts': None, 'class_counts': None,
'num_classes': None, 'num_classes': None,
'result': 1}, 'result': 1},
{'testcase_name': 'class weights array', {'testcase_name': 'class weights array',
'class_weights': [1, 1], 'class_weights': [1, 1],
'class_counts': [1, 1], 'class_counts': [1, 1],
'num_classes': 2, 'num_classes': 2,
'result': [1, 1]}, 'result': [1, 1]},
{'testcase_name': 'class weights balanced', {'testcase_name': 'class weights balanced',
'class_weights': 'balanced', 'class_weights': 'balanced',
'class_counts': [1, 1], 'class_counts': [1, 1],
'num_classes': 2, 'num_classes': 2,
'result': [1, 1]}, 'result': [1, 1]},
]) ])
def test_class_calculate(self, def test_class_calculate(self,
class_weights, class_weights,
class_counts, class_counts,
num_classes, num_classes,
result result
): ):
clf = model.Bolton(1, 1) clf = model.Bolton(1, 1)
expected = clf.calculate_class_weights(class_weights, expected = clf.calculate_class_weights(class_weights,
class_counts, class_counts,
@ -447,14 +472,14 @@ class FitTests(keras_parameterized.TestCase):
'class_weights': 'balanced', 'class_weights': 'balanced',
'class_counts': None, 'class_counts': None,
'num_classes': 1, 'num_classes': 1,
'err_msg': 'err_msg': "Class counts must be provided if "
"Class counts must be provided if using class_weights=balanced"}, "using class_weights=balanced"},
{'testcase_name': 'no num classes', {'testcase_name': 'no num classes',
'class_weights': 'balanced', 'class_weights': 'balanced',
'class_counts': [1], 'class_counts': [1],
'num_classes': None, 'num_classes': None,
'err_msg': 'err_msg': 'num_classes must be provided if '
'num_classes must be provided if using class_weights=balanced'}, 'using class_weights=balanced'},
{'testcase_name': 'class counts not array', {'testcase_name': 'class counts not array',
'class_weights': 'balanced', 'class_weights': 'balanced',
'class_counts': 1, 'class_counts': 1,
@ -464,7 +489,7 @@ class FitTests(keras_parameterized.TestCase):
'class_weights': [1], 'class_weights': [1],
'class_counts': None, 'class_counts': None,
'num_classes': None, 'num_classes': None,
'err_msg': "You must pass a value for num_classes if" 'err_msg': "You must pass a value for num_classes if "
"creating an array of class_weights"}, "creating an array of class_weights"},
{'testcase_name': 'class counts array, improper shape', {'testcase_name': 'class counts array, improper shape',
'class_weights': [[1], [1]], 'class_weights': [[1], [1]],
@ -481,7 +506,8 @@ class FitTests(keras_parameterized.TestCase):
class_weights, class_weights,
class_counts, class_counts,
num_classes, num_classes,
err_msg): err_msg
):
clf = model.Bolton(1, 1) clf = model.Bolton(1, 1)
with self.assertRaisesRegexp(ValueError, err_msg): with self.assertRaisesRegexp(ValueError, err_msg):
expected = clf.calculate_class_weights(class_weights, expected = clf.calculate_class_weights(class_weights,


@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Private Optimizer for bolton method""" """Bolton Optimizer for bolton method"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -19,15 +19,16 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from privacy.bolton.loss import StrongConvexMixin
_private_attributes = ['_internal_optimizer', 'dtype'] _accepted_distributions = ['laplace']
class Private(optimizer_v2.OptimizerV2): class Bolton(optimizer_v2.OptimizerV2):
""" """
Private optimizer wraps another tf optimizer to be used Bolton optimizer wraps another tf optimizer to be used
as the visible optimizer to the tf model. No matter the optimizer as the visible optimizer to the tf model. No matter the optimizer
passed, "Private" enables the bolton model to control the learning rate passed, "Bolton" enables the bolton model to control the learning rate
based on the strongly convex loss. based on the strongly convex loss.
For more details on the strong convexity requirements, see: For more details on the strong convexity requirements, see:
@ -36,7 +37,8 @@ class Private(optimizer_v2.OptimizerV2):
""" """
def __init__(self, def __init__(self,
optimizer: optimizer_v2.OptimizerV2, optimizer: optimizer_v2.OptimizerV2,
dtype=tf.float32 loss: StrongConvexMixin,
dtype=tf.float32,
): ):
"""Constructor. """Constructor.
@ -44,15 +46,100 @@ class Private(optimizer_v2.OptimizerV2):
optimizer: Optimizer_v2 or subclass to be used as the optimizer optimizer: Optimizer_v2 or subclass to be used as the optimizer
(wrapped). (wrapped).
""" """
if not isinstance(loss, StrongConvexMixin):
raise ValueError("loss function must be a Strongly Convex and therfore"
"extend the StrongConvexMixin.")
self._private_attributes = ['_internal_optimizer',
'dtype',
'noise_distribution',
'epsilon',
'loss',
'class_weights',
'input_dim',
'n_samples',
'n_classes',
'layers',
'_model'
]
self._internal_optimizer = optimizer self._internal_optimizer = optimizer
self.dtype = dtype self.dtype = dtype
self.loss = loss
def get_config(self): def get_config(self):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer.
""" """
return self._internal_optimizer.get_config() return self._internal_optimizer.get_config()
def limit_learning_rate(self, is_eager, beta, gamma): def project_weights_to_r(self, force=False):
"""helper method to normalize the weights to the R-ball.
Args:
r: radius of "R-Ball". Scalar to normalize to.
force: True to normalize regardless of previous weight values.
False to check if weights > R-ball and only normalize then.
Returns:
"""
r = self.loss.radius()
for layer in self.layers:
if tf.executing_eagerly():
weight_norm = tf.norm(layer.kernel, axis=0)
if force:
layer.kernel = layer.kernel / (weight_norm / r)
elif tf.reduce_sum(tf.cast(weight_norm > r, dtype=self.dtype)) > 0:
layer.kernel = layer.kernel / (weight_norm / r)
else:
weight_norm = tf.norm(layer.kernel, axis=0)
if force:
layer.kernel = layer.kernel / (weight_norm / r)
else:
layer.kernel = tf.cond(
tf.reduce_sum(tf.cast(weight_norm > r, dtype=self.dtype)) > 0,
lambda: layer.kernel / (weight_norm / r),
lambda: layer.kernel
)
def get_noise(self, data_size, input_dim, output_dim, class_weight):
"""Sample noise to be added to weights for privacy guarantee
Args:
distribution: the distribution type to pull noise from
data_size: the number of samples
Returns: noise in shape of layer's weights to be added to the weights.
"""
loss = self.loss
distribution = self.noise_distribution.lower()
if distribution == _accepted_distributions[0]: # laplace
per_class_epsilon = self.epsilon / (output_dim)
l2_sensitivity = (2 *
loss.lipchitz_constant(class_weight)) / \
(loss.gamma() * data_size)
unit_vector = tf.random.normal(shape=(input_dim, output_dim),
mean=0,
seed=1,
stddev=1.0,
dtype=self.dtype)
unit_vector = unit_vector / tf.math.sqrt(
tf.reduce_sum(tf.math.square(unit_vector), axis=0)
)
beta = l2_sensitivity / per_class_epsilon
alpha = input_dim # input_dim
gamma = tf.random.gamma([output_dim],
alpha,
beta=1 / beta,
seed=1,
dtype=self.dtype
)
return unit_vector * gamma
raise NotImplementedError('Noise distribution: {0} is not '
'a valid distribution'.format(distribution))
def limit_learning_rate(self, beta, gamma):
"""Implements learning rate limitation that is required by the bolton """Implements learning rate limitation that is required by the bolton
method for sensitivity bounding of the strongly convex function. method for sensitivity bounding of the strongly convex function.
Sets the learning rate to the min(1/beta, 1/(gamma*t)) Sets the learning rate to the min(1/beta, 1/(gamma*t))
@ -65,20 +152,13 @@ class Private(optimizer_v2.OptimizerV2):
Returns: None Returns: None
""" """
numerator = tf.Variable(initial_value=1, dtype=self.dtype) numerator = tf.constant(1, dtype=self.dtype)
t = tf.cast(self._iterations, self.dtype) t = tf.cast(self._iterations, self.dtype)
# will exist on the internal optimizer # will exist on the internal optimizer
pred = numerator / beta < numerator / (gamma * t) if numerator / beta < numerator / (gamma * t):
if is_eager: # check eagerly self.learning_rate = numerator / beta
if pred:
self.learning_rate = numerator / beta
else:
self.learning_rate = numerator / (gamma * t)
else: else:
if pred: self.learning_rate = numerator / (gamma * t)
self.learning_rate = numerator / beta
else:
self.learning_rate = numerator / (gamma * t)
def from_config(self, *args, **kwargs): def from_config(self, *args, **kwargs):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer.
@ -92,14 +172,25 @@ class Private(optimizer_v2.OptimizerV2):
Args: Args:
name: name:
Returns: attribute from Private if specified to come from self, else Returns: attribute from Bolton if specified to come from self, else
from _internal_optimizer. from _internal_optimizer.
""" """
if name in _private_attributes: if name == '_private_attributes':
return getattr(self, name)
elif name in self._private_attributes:
return getattr(self, name) return getattr(self, name)
optim = object.__getattribute__(self, '_internal_optimizer') optim = object.__getattribute__(self, '_internal_optimizer')
return object.__getattribute__(optim, name) try:
return object.__getattribute__(optim, name)
except AttributeError:
raise AttributeError("Neither '{0}' nor '{1}' object has attribute '{2}'"
"".format(
self.__class__.__name__,
self._internal_optimizer.__class__.__name__,
name
)
)
def __setattr__(self, key, value): def __setattr__(self, key, value):
""" Set attribute to self instance if its the internal optimizer. """ Set attribute to self instance if its the internal optimizer.
@ -112,7 +203,9 @@ class Private(optimizer_v2.OptimizerV2):
Returns: Returns:
""" """
if key in _private_attributes: if key == '_private_attributes':
object.__setattr__(self, key, value)
elif key in key in self._private_attributes:
object.__setattr__(self, key, value) object.__setattr__(self, key, value)
else: else:
setattr(self._internal_optimizer, key, value) setattr(self._internal_optimizer, key, value)
@ -130,24 +223,135 @@ class Private(optimizer_v2.OptimizerV2):
def get_updates(self, loss, params): def get_updates(self, loss, params):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer.
""" """
return self._internal_optimizer.get_updates(loss, params) # self.layers = params
out = self._internal_optimizer.get_updates(loss, params)
self.limit_learning_rate(self.loss.beta(self.class_weights),
self.loss.gamma()
)
self.project_weights_to_r()
return out
def apply_gradients(self, *args, **kwargs): def apply_gradients(self, *args, **kwargs):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer.
""" """
return self._internal_optimizer.apply_gradients(*args, **kwargs) # grads_and_vars = kwargs.get('grads_and_vars', None)
# grads_and_vars = optimizer_v2._filter_grads(grads_and_vars)
# var_list = [v for (_, v) in grads_and_vars]
# self.layers = var_list
out = self._internal_optimizer.apply_gradients(*args, **kwargs)
self.limit_learning_rate(self.loss.beta(self.class_weights),
self.loss.gamma()
)
self.project_weights_to_r()
return out
def minimize(self, *args, **kwargs): def minimize(self, *args, **kwargs):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer.
""" """
return self._internal_optimizer.minimize(*args, **kwargs) # self.layers = kwargs.get('var_list', None)
out = self._internal_optimizer.minimize(*args, **kwargs)
self.limit_learning_rate(self.loss.beta(self.class_weights),
self.loss.gamma()
)
self.project_weights_to_r()
return out
def _compute_gradients(self, *args, **kwargs): def _compute_gradients(self, *args, **kwargs):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer.
""" """
# self.layers = kwargs.get('var_list', None)
return self._internal_optimizer._compute_gradients(*args, **kwargs) return self._internal_optimizer._compute_gradients(*args, **kwargs)
def get_gradients(self, *args, **kwargs): def get_gradients(self, *args, **kwargs):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer.
""" """
# self.layers = kwargs.get('params', None)
return self._internal_optimizer.get_gradients(*args, **kwargs) return self._internal_optimizer.get_gradients(*args, **kwargs)
def __enter__(self):
noise_distribution = self.noise_distribution
epsilon = self.epsilon
class_weights = self.class_weights
n_samples = self.n_samples
if noise_distribution not in _accepted_distributions:
raise ValueError('Detected noise distribution: {0} not one of: {1} valid'
'distributions'.format(noise_distribution,
_accepted_distributions))
self.noise_distribution = noise_distribution
self.epsilon = epsilon
self.class_weights = class_weights
self.n_samples = n_samples
return self
def __call__(self,
noise_distribution,
epsilon,
layers,
class_weights,
n_samples,
n_classes,
):
"""
Args:
noise_distribution: the noise distribution to pick.
see _accepted_distributions and get_noise for
possible values.
epsilon: privacy parameter. Lower gives more privacy but less utility.
class_weights: class_weights used
n_samples number of rows/individual samples in the training set
n_classes: number of output classes
layers: list of Keras/Tensorflow layers.
"""
if epsilon <= 0:
raise ValueError('Detected epsilon: {0}. '
'Valid range is 0 < epsilon <inf'.format(epsilon))
self.noise_distribution = noise_distribution
self.epsilon = epsilon
self.class_weights = class_weights
self.n_samples = n_samples
self.n_classes = n_classes
self.layers = layers
return self
def __exit__(self, *args):
"""Exit call from with statement.
used to
1.reset the model and fit parameters passed to the optimizer
to enable the Bolton Privacy guarantees. These are reset to ensure
that any future calls to fit with the same instance of the optimizer
will properly error out.
2.call post-fit methods normalizing/projecting the model weights and
adding noise to the weights.
"""
# for param in self.layers:
# if param.name.find('kernel') != -1 or param.name.find('weight') != -1:
# input_dim = param.numpy().shape[0]
# print(param)
# noise = -1 * self.get_noise(self.n_samples,
# input_dim,
# self.n_classes,
# self.class_weights
# )
# print(tf.math.subtract(param, noise))
# param.assign(tf.math.subtract(param, noise))
self.project_weights_to_r(True)
for layer in self.layers:
input_dim, output_dim = layer.kernel.shape
noise = self.get_noise(self.n_samples,
input_dim,
output_dim,
self.class_weights
)
layer.kernel = tf.math.add(layer.kernel, noise)
self.noise_distribution = None
self.epsilon = -1
self.class_weights = None
self.n_samples = None
self.input_dim = None
self.n_classes = None
self.layers = None
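
For reference, the noise added in __exit__ comes from get_noise above: with per-class budget epsilon / n_classes and L2 sensitivity 2 * L / (gamma * n), it draws a random unit direction per output column and a Gamma(alpha=input_dim, scale=sensitivity / per_class_epsilon) magnitude, i.e. a high-dimensional Laplace-style mechanism. Below is a runnable sketch with stand-in constants for the loss's Lipschitz constant L and strong-convexity parameter gamma; the function name and defaults are illustrative.

import tensorflow as tf

def sample_laplace_ball_noise(input_dim, n_classes, epsilon, data_size,
                              lipschitz_constant=1.0, gamma=1.0, seed=1):
  """Sample output-perturbation noise shaped like a Dense kernel.

  Direction: a random unit vector per output column.
  Magnitude: Gamma(alpha=input_dim, rate=per_class_epsilon / l2_sensitivity),
  matching the laplace branch of get_noise in the diff above.
  """
  per_class_epsilon = epsilon / n_classes
  l2_sensitivity = (2.0 * lipschitz_constant) / (gamma * data_size)

  direction = tf.random.normal((input_dim, n_classes), mean=0.0,
                               stddev=1.0, seed=seed)
  direction = direction / tf.math.sqrt(
      tf.reduce_sum(tf.math.square(direction), axis=0))

  scale = l2_sensitivity / per_class_epsilon
  magnitude = tf.random.gamma([n_classes], alpha=float(input_dim),
                              beta=1.0 / scale, seed=seed)
  return direction * magnitude  # shape (input_dim, n_classes); added to the kernel

noise = sample_laplace_ball_noise(input_dim=16, n_classes=2,
                                  epsilon=2.0, data_size=100)
print(noise.shape)  # (16, 2)

Alongside this, limit_learning_rate clamps the step size to min(1 / beta, 1 / (gamma * t)), where beta is the smoothness constant of the loss and t is the iteration count.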


@ -21,19 +21,129 @@ import tensorflow as tf
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2 from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from privacy.bolton import model from tensorflow.python.keras.regularizers import L1L2
from privacy.bolton import optimizer as opt from tensorflow.python.keras import losses
from tensorflow.python.keras.models import Model
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import test_util
from absl.testing import parameterized from absl.testing import parameterized
from absl.testing import absltest from privacy.bolton.loss import StrongConvexMixin
from privacy.bolton import optimizer as opt
class TestModel(Model):
"""
Bolton episilon-delta model
Uses 4 key steps to achieve privacy guarantees:
1. Adds noise to weights after training (output perturbation).
2. Projects weights to R after each batch
3. Limits learning rate
4. Use a strongly convex loss function (see compile)
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et. al.
"""
def __init__(self, n_classes=2):
"""
Args:
n_classes: number of output classes to predict.
epsilon: level of privacy guarantee
noise_distribution: distribution to pull weight perturbations from
weights_initializer: initializer for weights
seed: random seed to use
dtype: data type to use for tensors
"""
super(TestModel, self).__init__(name='bolton', dynamic=False)
self.n_classes = n_classes
self.layer_input_shape = (16, 1)
self.output_layer = tf.keras.layers.Dense(
self.n_classes,
input_shape=self.layer_input_shape,
kernel_regularizer=L1L2(l2=1),
kernel_initializer='glorot_uniform',
)
# def call(self, inputs):
# """Forward pass of network
#
# Args:
# inputs: inputs to neural network
#
# Returns:
#
# """
# return self.output_layer(inputs)
class TestLoss(losses.Loss, StrongConvexMixin):
"""Test loss function for testing Bolton model"""
def __init__(self, reg_lambda, C, radius_constant, name='test'):
super(TestLoss, self).__init__(name=name)
self.reg_lambda = reg_lambda
self.C = C
self.radius_constant = radius_constant
def radius(self):
"""Radius of R-Ball (value to normalize weights to after each batch)
Returns: radius
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def gamma(self):
""" Gamma strongly convex
Returns: gamma
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def beta(self, class_weight):
"""Beta smoothess
Args:
class_weight: the class weights used.
Returns: Beta
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def lipchitz_constant(self, class_weight):
""" L lipchitz continuous
Args:
class_weight: class weights used
Returns: L
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def call(self, val0, val1):
"""Loss function that is minimized at the mean of the input points."""
return 0.5 * tf.reduce_sum(tf.math.squared_difference(val0, val1), axis=1)
def max_class_weight(self, class_weight):
if class_weight is None:
return 1
def kernel_regularizer(self):
return L1L2(l2=self.reg_lambda)
class TestOptimizer(OptimizerV2): class TestOptimizer(OptimizerV2):
"""Optimizer used for testing the Private optimizer""" """Optimizer used for testing the Bolton optimizer"""
def __init__(self): def __init__(self):
super(TestOptimizer, self).__init__('test') super(TestOptimizer, self).__init__('test')
self.not_private = 'test' self.not_private = 'test'
self.iterations = tf.Variable(1, dtype=tf.float32) self.iterations = tf.constant(1, dtype=tf.float32)
self._iterations = tf.Variable(1, dtype=tf.float32) self._iterations = tf.constant(1, dtype=tf.float32)
def _compute_gradients(self, loss, var_list, grad_loss=None): def _compute_gradients(self, loss, var_list, grad_loss=None):
return 'test' return 'test'
@ -41,7 +151,7 @@ class TestOptimizer(OptimizerV2):
def get_config(self): def get_config(self):
return 'test' return 'test'
def from_config(cls, config, custom_objects=None): def from_config(self, config, custom_objects=None):
return 'test' return 'test'
def _create_slots(self): def _create_slots(self):
@ -65,34 +175,22 @@ class TestOptimizer(OptimizerV2):
def get_gradients(self, loss, params): def get_gradients(self, loss, params):
return 'test' return 'test'
class PrivateTest(keras_parameterized.TestCase): def limit_learning_rate(self):
"""Private Optimizer tests""" return 'test'
class BoltonOptimizerTest(keras_parameterized.TestCase):
"""Bolton Optimizer tests"""
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'branch True, beta', {'testcase_name': 'branch beta',
'fn': 'limit_learning_rate', 'fn': 'limit_learning_rate',
'args': [True, 'args': [tf.Variable(2, dtype=tf.float32),
tf.Variable(2, dtype=tf.float32),
tf.Variable(1, dtype=tf.float32)], tf.Variable(1, dtype=tf.float32)],
'result': tf.Variable(0.5, dtype=tf.float32), 'result': tf.Variable(0.5, dtype=tf.float32),
'test_attr': 'learning_rate'}, 'test_attr': 'learning_rate'},
{'testcase_name': 'branch True, gamma', {'testcase_name': 'branch gamma',
'fn': 'limit_learning_rate', 'fn': 'limit_learning_rate',
'args': [True, 'args': [tf.Variable(1, dtype=tf.float32),
tf.Variable(1, dtype=tf.float32),
tf.Variable(1, dtype=tf.float32)],
'result': tf.Variable(1, dtype=tf.float32),
'test_attr': 'learning_rate'},
{'testcase_name': 'branch False, beta',
'fn': 'limit_learning_rate',
'args': [False,
tf.Variable(2, dtype=tf.float32),
tf.Variable(1, dtype=tf.float32)],
'result': tf.Variable(0.5, dtype=tf.float32),
'test_attr': 'learning_rate'},
{'testcase_name': 'branch False, gamma',
'fn': 'limit_learning_rate',
'args': [False,
tf.Variable(1, dtype=tf.float32),
tf.Variable(1, dtype=tf.float32)], tf.Variable(1, dtype=tf.float32)],
'result': tf.Variable(1, dtype=tf.float32), 'result': tf.Variable(1, dtype=tf.float32),
'test_attr': 'learning_rate'}, 'test_attr': 'learning_rate'},
@ -101,9 +199,26 @@ class PrivateTest(keras_parameterized.TestCase):
'args': ['dtype'], 'args': ['dtype'],
'result': tf.float32, 'result': tf.float32,
'test_attr': None}, 'test_attr': None},
{'testcase_name': 'project_weights_to_r',
'fn': 'project_weights_to_r',
'args': ['dtype'],
'result': tf.float32,
'test_attr': None},
]) ])
def test_fn(self, fn, args, result, test_attr): def test_fn(self, fn, args, result, test_attr):
private = opt.Private(TestOptimizer()) """test that a fn of Bolton optimizer is working as expected.
Args:
fn: method of Optimizer to test
args: args to optimizer fn
result: the expected result
test_attr: None if the fn returns the test result. Otherwise, this is
the attribute of Bolton to check against result with.
"""
tf.random.set_seed(1)
loss = TestLoss(1, 1, 1)
private = opt.Bolton(TestOptimizer(), loss)
res = getattr(private, fn, None)(*args) res = getattr(private, fn, None)(*args)
if test_attr is not None: if test_attr is not None:
res = getattr(private, test_attr, None) res = getattr(private, test_attr, None)
@ -142,41 +257,88 @@ class PrivateTest(keras_parameterized.TestCase):
'args': [1, 1]}, 'args': [1, 1]},
]) ])
def test_rerouted_function(self, fn, args): def test_rerouted_function(self, fn, args):
""" tests that a method of the internal optimizer is correctly routed from
the Bolton instance to the internal optimizer instance (TestOptimizer,
here).
Args:
fn: fn to test
args: arguments to that fn
"""
loss = TestLoss(1, 1, 1)
optimizer = TestOptimizer() optimizer = TestOptimizer()
optimizer = opt.Private(optimizer) optimizer = opt.Bolton(optimizer, loss)
self.assertEqual( model = TestModel(2)
getattr(optimizer, fn, lambda: 'fn not found')(*args), model.compile(optimizer, loss)
'test' model.layers[0].kernel_initializer(model.layer_input_shape)
) print(model.layers[0].__dict__)
with optimizer('laplace', 2, model.layers, 1, 1, model.n_classes):
self.assertEqual(
getattr(optimizer, fn, lambda: 'fn not found')(*args),
'test'
)
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'fn: limit_learning_rate', {'testcase_name': 'fn: limit_learning_rate',
'fn': 'limit_learning_rate', 'fn': 'limit_learning_rate',
'args': [1, 1, 1]} 'args': [1, 1, 1]},
{'testcase_name': 'fn: project_weights_to_r',
'fn': 'project_weights_to_r',
'args': []},
{'testcase_name': 'fn: get_noise',
'fn': 'get_noise',
'args': [1, 1, 1, 1]},
]) ])
def test_not_reroute_fn(self, fn, args): def test_not_reroute_fn(self, fn, args):
"""Test that a fn that should not be rerouted to the internal optimizer is
in face not rerouted.
Args:
fn: fn to test
args: arguments to that fn
"""
optimizer = TestOptimizer() optimizer = TestOptimizer()
optimizer = opt.Private(optimizer) loss = TestLoss(1, 1, 1)
optimizer = opt.Bolton(optimizer, loss)
self.assertNotEqual(getattr(optimizer, fn, lambda: 'test')(*args), self.assertNotEqual(getattr(optimizer, fn, lambda: 'test')(*args),
'test') 'test')
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'attr: not_private', {'testcase_name': 'attr: _iterations',
'attr': 'not_private'} 'attr': '_iterations'}
]) ])
def test_reroute_attr(self, attr): def test_reroute_attr(self, attr):
""" test that attribute of internal optimizer is correctly rerouted to
the internal optimizer
Args:
attr: attribute to test
result: result after checking attribute
"""
loss = TestLoss(1, 1, 1)
internal_optimizer = TestOptimizer() internal_optimizer = TestOptimizer()
optimizer = opt.Private(internal_optimizer) optimizer = opt.Bolton(internal_optimizer, loss)
self.assertEqual(optimizer._internal_optimizer, internal_optimizer) self.assertEqual(getattr(optimizer, attr),
getattr(internal_optimizer, attr)
)
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'attr: _internal_optimizer', {'testcase_name': 'attr does not exist',
'attr': '_internal_optimizer'} 'attr': '_not_valid'}
]) ])
def test_not_reroute_attr(self, attr): def test_attribute_error(self, attr):
""" test that attribute of internal optimizer is correctly rerouted to
the internal optimizer
Args:
attr: attribute to test
result: result after checking attribute
"""
loss = TestLoss(1, 1, 1)
internal_optimizer = TestOptimizer() internal_optimizer = TestOptimizer()
optimizer = opt.Private(internal_optimizer) optimizer = opt.Bolton(internal_optimizer, loss)
self.assertEqual(optimizer._internal_optimizer, internal_optimizer) with self.assertRaises(AttributeError):
getattr(optimizer, attr)
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()
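
The rerouting tests above pin down the wrapper's attribute protocol: names listed in _private_attributes live on the Bolton instance, every other attribute get and set falls through to the wrapped optimizer, and unknown names raise an AttributeError that mentions both classes. A toy sketch of that delegation pattern in plain Python (this is not the Bolton class itself, just the idea it implements):

class Wrapper(object):
  _private_attributes = ['_internal', 'dtype']

  def __init__(self, internal):
    self._internal = internal  # kept on the wrapper (listed above)
    self.dtype = 'float32'     # kept on the wrapper (listed above)

  def __getattr__(self, name):
    # Only called when normal lookup fails; fall through to the wrapped object.
    try:
      return getattr(self._internal, name)
    except AttributeError:
      raise AttributeError(
          "Neither '{0}' nor '{1}' object has attribute '{2}'".format(
              type(self).__name__, type(self._internal).__name__, name))

  def __setattr__(self, key, value):
    if key in self._private_attributes:
      object.__setattr__(self, key, value)
    else:
      setattr(self._internal, key, value)  # rerouted to the wrapped object

class Inner(object):
  learning_rate = 0.1

w = Wrapper(Inner())
print(w.learning_rate)            # 0.1, read through to Inner
w.learning_rate = 0.5             # written through to Inner, not Wrapper
print(w._internal.learning_rate)  # 0.5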