Code style and documentation changes.

This commit is contained in:
Christopher Choquette Choo 2019-07-30 15:12:22 -04:00
parent fb12ee047f
commit 2065f2b16a
10 changed files with 105 additions and 103 deletions

View file

@ -42,8 +42,8 @@ else:
from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
from privacy.bolton.models import BoltonModel
from privacy.bolton.optimizers import Bolton
from privacy.bolton.models import BoltOnModel
from privacy.bolton.optimizers import BoltOn
from privacy.bolton.losses import StrongConvexMixin
from privacy.bolton.losses import StrongConvexBinaryCrossentropy
from privacy.bolton.losses import StrongConvexHuber

View file

@ -1,25 +1,26 @@
# Bolton Subpackage
# BoltOn Subpackage
This package contains source code for the Bolton method. This method is a subset
of methods used in the ensuring privacy in machine learning that leverages
This package contains source code for the BoltOn method, a particular
differential-privacy (DP) technique that uses output perturbations and leverages
additional assumptions to provide a new way of approaching the privacy
guarantees.
## Bolton Description
## BoltOn Description
This method uses 4 key steps to achieve privacy guarantees:
1. Adds noise to weights after training (output perturbation).
2. Projects weights to R after each batch
2. Projects weights to R, the radius of the hypothesis space,
after each batch. This value is configurable by the user.
3. Limits learning rate
4. Use a strongly convex loss function (see compile)
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et al.
Descent-based Analytics by Xi Wu et al. at https://arxiv.org/pdf/1606.04722.pdf
## Why Bolton?
## Why BoltOn?
The major difference for the Bolton method is that it injects noise post model
The major difference for the BoltOn method is that it injects noise post model
convergence, rather than noising gradients or weights during training. This
approach requires some additional constraints listed in the Description.
Should the use-case and model satisfy these constraints, this is another

View file

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bolton Method for privacy."""
"""BoltOn Method for privacy."""
import sys
from distutils.version import LooseVersion
import tensorflow as tf
@ -23,7 +23,7 @@ if LooseVersion(tf.__version__) < LooseVersion("2.0.0"):
if hasattr(sys, "skip_tf_privacy_import"): # Useful for standalone scripts.
pass
else:
from privacy.bolton.models import BoltonModel # pylint: disable=g-import-not-at-top
from privacy.bolton.optimizers import Bolton # pylint: disable=g-import-not-at-top
from privacy.bolton.models import BoltOnModel # pylint: disable=g-import-not-at-top
from privacy.bolton.optimizers import BoltOn # pylint: disable=g-import-not-at-top
from privacy.bolton.losses import StrongConvexHuber # pylint: disable=g-import-not-at-top
from privacy.bolton.losses import StrongConvexBinaryCrossentropy # pylint: disable=g-import-not-at-top

View file

@ -29,7 +29,7 @@ class StrongConvexMixin: # pylint: disable=old-style-class
"""Strong Convex Mixin base class.
Strong Convex Mixin base class for any loss function that will be used with
Bolton model. Subclasses must be strongly convex and implement the
BoltOn model. Subclasses must be strongly convex and implement the
associated constants. They must also conform to the requirements of tf losses
(see super class).

View file

@ -372,7 +372,7 @@ class HuberTests(keras_parameterized.TestCase):
Args:
logits: unscaled output of model
y_true: label
delta:
delta: delta value for StrongConvexHuber loss.
result: correct loss calculation value
"""
logits = tf.Variable(logits, False, dtype=tf.float32)

View file

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bolton model for bolton method of differentially private ML."""
"""BoltOn model for bolton method of differentially private ML."""
from __future__ import absolute_import
from __future__ import division
@ -21,11 +21,11 @@ from tensorflow.python.framework import ops as _ops
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.models import Model
from privacy.bolton.losses import StrongConvexMixin
from privacy.bolton.optimizers import Bolton
from privacy.bolton.optimizers import BoltOn
class BoltonModel(Model): # pylint: disable=abstract-method
"""Bolton episilon-delta differential privacy model.
class BoltOnModel(Model): # pylint: disable=abstract-method
"""BoltOn episilon-delta differential privacy model.
The privacy guarantees are dependent on the noise that is sampled. Please
see the paper linked below for more details.
@ -52,7 +52,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
seed: random seed to use
dtype: data type to use for tensors
"""
super(BoltonModel, self).__init__(name='bolton', dynamic=False)
super(BoltOnModel, self).__init__(name='bolton', dynamic=False)
if n_outputs <= 0:
raise ValueError('n_outputs = {0} is not valid. Must be > 0.'.format(
n_outputs
@ -69,6 +69,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
inputs: inputs to neural network
Returns:
Output logits for the given inputs.
"""
return self.output_layer(inputs)
@ -78,11 +79,11 @@ class BoltonModel(Model): # pylint: disable=abstract-method
loss,
kernel_initializer=tf.initializers.GlorotUniform,
**kwargs): # pylint: disable=arguments-differ
"""See super class. Default optimizer used in Bolton method is SGD.
"""See super class. Default optimizer used in BoltOn method is SGD.
Args:
optimizer: The optimizer to use. This will be automatically wrapped
with the Bolton Optimizer.
with the BoltOn Optimizer.
loss: The loss function to use. Must be a StrongConvex loss (extend the
StrongConvexMixin).
kernel_initializer: The kernel initializer to use for the single layer.
@ -99,11 +100,11 @@ class BoltonModel(Model): # pylint: disable=abstract-method
kernel_initializer=kernel_initializer(),
)
self._layers_instantiated = True
if not isinstance(optimizer, Bolton):
if not isinstance(optimizer, BoltOn):
optimizer = optimizers.get(optimizer)
optimizer = Bolton(optimizer, loss)
optimizer = BoltOn(optimizer, loss)
super(BoltonModel, self).compile(optimizer, loss=loss, **kwargs)
super(BoltOnModel, self).compile(optimizer, loss=loss, **kwargs)
def fit(self,
x=None,
@ -115,7 +116,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
noise_distribution='laplace',
steps_per_epoch=None,
**kwargs): # pylint: disable=arguments-differ
"""Reroutes to super fit with Bolton delta-epsilon privacy requirements.
"""Reroutes to super fit with BoltOn delta-epsilon privacy requirements.
Note, inputs must be normalized s.t. ||x|| < 1.
Requirements are as follows:
@ -126,9 +127,9 @@ class BoltonModel(Model): # pylint: disable=abstract-method
See super implementation for more details.
Args:
x:
y:
batch_size:
x: Inputs to fit on, see super.
y: Labels to fit on, see super.
batch_size: The batch size to use for training, see super.
class_weight: the class weights to be used. Can be a scalar or 1D tensor
whose dim == n_classes.
n_samples: the number of individual samples in x.
@ -139,7 +140,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
**kwargs: kwargs to keras Model.fit. See super.
Returns:
output
Output from super fit method.
"""
if class_weight is None:
class_weight_ = self.calculate_class_weights(class_weight)
@ -170,7 +171,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
class_weight_,
data_size,
batch_size_) as _:
out = super(BoltonModel, self).fit(x=x,
out = super(BoltOnModel, self).fit(x=x,
y=y,
batch_size=batch_size,
class_weight=class_weight,
@ -192,18 +193,18 @@ class BoltonModel(Model): # pylint: disable=abstract-method
is a generator. See super method and fit for more details.
Args:
generator:
generator: Inputs generator following Tensorflow guidelines, see super.
class_weight: the class weights to be used. Can be a scalar or 1D tensor
whose dim == n_classes.
noise_distribution: the distribution to get noise from.
epsilon: privacy parameter, which trades off utility and privacy. See
Bolton paper for more description.
BoltOn paper for more description.
n_samples: number of individual samples in x
steps_per_epoch:
steps_per_epoch: Number of steps per training epoch, see super.
**kwargs: **kwargs
Returns:
output
Output from super fit_generator method.
"""
if class_weight is None:
class_weight = self.calculate_class_weights(class_weight)
@ -224,7 +225,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
class_weight,
data_size,
batch_size) as _:
out = super(BoltonModel, self).fit_generator(
out = super(BoltOnModel, self).fit_generator(
generator,
class_weight=class_weight,
steps_per_epoch=steps_per_epoch,

View file

@ -26,11 +26,11 @@ from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
from tensorflow.python.keras.regularizers import L1L2
from privacy.bolton import models
from privacy.bolton.losses import StrongConvexMixin
from privacy.bolton.optimizers import Bolton
from privacy.bolton.optimizers import BoltOn
class TestLoss(losses.Loss, StrongConvexMixin):
"""Test loss function for testing Bolton model."""
"""Test loss function for testing BoltOn model."""
def __init__(self, reg_lambda, c_arg, radius_constant, name='test'):
super(TestLoss, self).__init__(name=name)
@ -105,7 +105,7 @@ class TestLoss(losses.Loss, StrongConvexMixin):
class TestOptimizer(OptimizerV2):
"""Test optimizer used for testing Bolton model."""
"""Test optimizer used for testing BoltOn model."""
def __init__(self):
super(TestOptimizer, self).__init__('test')
@ -138,14 +138,14 @@ class InitTests(keras_parameterized.TestCase):
},
])
def test_init_params(self, n_outputs):
"""Test initialization of BoltonModel.
"""Test initialization of BoltOnModel.
Args:
n_outputs: number of output neurons
"""
# test valid domains for each variable
clf = models.BoltonModel(n_outputs)
self.assertIsInstance(clf, models.BoltonModel)
clf = models.BoltOnModel(n_outputs)
self.assertIsInstance(clf, models.BoltOnModel)
@parameterized.named_parameters([
{'testcase_name': 'invalid n_outputs',
@ -153,14 +153,14 @@ class InitTests(keras_parameterized.TestCase):
},
])
def test_bad_init_params(self, n_outputs):
"""test bad initializations of BoltonModel that should raise errors.
"""test bad initializations of BoltOnModel that should raise errors.
Args:
n_outputs: number of output neurons
"""
# test invalid domains for each variable, especially noise
with self.assertRaises(ValueError):
models.BoltonModel(n_outputs)
models.BoltOnModel(n_outputs)
@parameterized.named_parameters([
{'testcase_name': 'string compile',
@ -175,7 +175,7 @@ class InitTests(keras_parameterized.TestCase):
},
])
def test_compile(self, n_outputs, loss, optimizer):
"""Test compilation of BoltonModel.
"""Test compilation of BoltOnModel.
Args:
n_outputs: number of output neurons
@ -184,7 +184,7 @@ class InitTests(keras_parameterized.TestCase):
"""
# test compilation of valid tf.optimizer and tf.loss
with self.cached_session():
clf = models.BoltonModel(n_outputs)
clf = models.BoltOnModel(n_outputs)
clf.compile(optimizer, loss)
self.assertEqual(clf.loss, loss)
@ -201,7 +201,7 @@ class InitTests(keras_parameterized.TestCase):
}
])
def test_bad_compile(self, n_outputs, loss, optimizer):
"""test bad compilations of BoltonModel that should raise errors.
"""test bad compilations of BoltOnModel that should raise errors.
Args:
n_outputs: number of output neurons
@ -211,7 +211,7 @@ class InitTests(keras_parameterized.TestCase):
# test compilaton of invalid tf.optimizer and non instantiated loss.
with self.cached_session():
with self.assertRaises((ValueError, AttributeError)):
clf = models.BoltonModel(n_outputs)
clf = models.BoltOnModel(n_outputs)
clf.compile(optimizer, loss)
@ -276,9 +276,9 @@ def _do_fit(n_samples,
distribution: distribution to get noise from.
Returns:
BoltonModel instsance
BoltOnModel instsance
"""
clf = models.BoltonModel(n_outputs)
clf = models.BoltOnModel(n_outputs)
clf.compile(optimizer, loss)
if generator:
x = _cat_dataset(
@ -328,14 +328,14 @@ class FitTests(keras_parameterized.TestCase):
},
])
def test_fit(self, generator, reset_n_samples):
"""Tests fitting of BoltonModel.
"""Tests fitting of BoltOnModel.
Args:
generator: True for generator test, False for iterator test.
reset_n_samples: True to reset the n_samples to None, False does nothing
"""
loss = TestLoss(1, 1, 1)
optimizer = Bolton(TestOptimizer(), loss)
optimizer = BoltOn(TestOptimizer(), loss)
n_classes = 2
input_dim = 5
epsilon = 1
@ -360,7 +360,7 @@ class FitTests(keras_parameterized.TestCase):
},
])
def test_fit_gen(self, generator):
"""Tests the fit_generator method of BoltonModel.
"""Tests the fit_generator method of BoltOnModel.
Args:
generator: True to test with a generator dataset
@ -371,7 +371,7 @@ class FitTests(keras_parameterized.TestCase):
input_dim = 5
batch_size = 1
n_samples = 10
clf = models.BoltonModel(n_classes)
clf = models.BoltOnModel(n_classes)
clf.compile(optimizer, loss)
x = _cat_dataset(
n_samples,
@ -456,7 +456,7 @@ class FitTests(keras_parameterized.TestCase):
num_classes: number of outputs neurons
result: expected result
"""
clf = models.BoltonModel(1, 1)
clf = models.BoltOnModel(1, 1)
expected = clf.calculate_class_weights(class_weights,
class_counts,
num_classes)
@ -523,7 +523,7 @@ class FitTests(keras_parameterized.TestCase):
num_classes: number of outputs neurons.
err_msg: The expected error message.
"""
clf = models.BoltonModel(1, 1)
clf = models.BoltOnModel(1, 1)
with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method
clf.calculate_class_weights(class_weights,
class_counts,

View file

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bolton Optimizer for bolton method."""
"""BoltOn Optimizer for bolton method."""
from __future__ import absolute_import
from __future__ import division
@ -45,14 +45,14 @@ class GammaBetaDecreasingStep(
Returns:
decayed learning rate to minimum of 1/beta and 1/(gamma * step) as per
the Bolton privacy requirements.
the BoltOn privacy requirements.
"""
if not self.is_init:
raise AttributeError('Please initialize the {0} Learning Rate Scheduler.'
'This is performed automatically by using the '
'{1} as a context manager, '
'as desired'.format(self.__class__.__name__,
Bolton.__class__.__name__
BoltOn.__class__.__name__
)
)
dtype = self.beta.dtype
@ -86,22 +86,22 @@ class GammaBetaDecreasingStep(
self.gamma = None
class Bolton(optimizer_v2.OptimizerV2):
"""Wrap another tf optimizer with Bolton privacy protocol.
class BoltOn(optimizer_v2.OptimizerV2):
"""Wrap another tf optimizer with BoltOn privacy protocol.
Bolton optimizer wraps another tf optimizer to be used
BoltOn optimizer wraps another tf optimizer to be used
as the visible optimizer to the tf model. No matter the optimizer
passed, "Bolton" enables the bolton model to control the learning rate
passed, "BoltOn" enables the bolton model to control the learning rate
based on the strongly convex loss.
To use the Bolton method, you must:
To use the BoltOn method, you must:
1. instantiate it with an instantiated tf optimizer and StrongConvexLoss.
2. use it as a context manager around your .fit method internals.
This can be accomplished by the following:
optimizer = tf.optimizers.SGD()
loss = privacy.bolton.losses.StrongConvexBinaryCrossentropy()
bolton = Bolton(optimizer, loss)
bolton = BoltOn(optimizer, loss)
with bolton(*args) as _:
model.fit()
The args required for the context manager can be found in the __call__
@ -142,7 +142,7 @@ class Bolton(optimizer_v2.OptimizerV2):
'_is_init'
]
self._internal_optimizer = optimizer
self.learning_rate = GammaBetaDecreasingStep() # use the Bolton Learning
self.learning_rate = GammaBetaDecreasingStep() # use the BoltOn Learning
# rate scheduler, as required for privacy guarantees. This will still need
# to get values from the loss function near the time that .fit is called
# on the model (when this optimizer will be called as a context manager)
@ -162,7 +162,7 @@ class Bolton(optimizer_v2.OptimizerV2):
False to check if weights > R-ball and only normalize then.
Raises:
Exception:
Exception: If not called from inside this optimizer context.
"""
if not self._is_init:
raise Exception('This method must be called from within the optimizer\'s '
@ -190,7 +190,7 @@ class Bolton(optimizer_v2.OptimizerV2):
Noise in shape of layer's weights to be added to the weights.
Raises:
Exception:
Exception: If not called from inside this optimizer's context.
"""
if not self._is_init:
raise Exception('This method must be called from within the optimizer\'s '
@ -234,10 +234,10 @@ class Bolton(optimizer_v2.OptimizerV2):
from the _internal_optimizer instance.
Args:
name:
name: Name of attribute to get from this or aggregate optimizer.
Returns:
attribute from Bolton if specified to come from self, else
attribute from BoltOn if specified to come from self, else
from _internal_optimizer.
"""
if name == '_private_attributes' or name in self._private_attributes:
@ -336,7 +336,7 @@ class Bolton(optimizer_v2.OptimizerV2):
batch_size: batch size used.
Returns:
self
self, to be used by the __enter__ method for context.
"""
if epsilon <= 0:
raise ValueError('Detected epsilon: {0}. '
@ -361,7 +361,7 @@ class Bolton(optimizer_v2.OptimizerV2):
Used to:
1.reset the model and fit parameters passed to the optimizer
to enable the Bolton Privacy guarantees. These are reset to ensure
to enable the BoltOn Privacy guarantees. These are reset to ensure
that any future calls to fit with the same instance of the optimizer
will properly error out.
@ -369,7 +369,7 @@ class Bolton(optimizer_v2.OptimizerV2):
adding noise to the weights.
Args:
*args: *args
*args: encompasses the type, value, and traceback values which are unused.
"""
self.project_weights_to_r(True)
for layer in self.layers:

View file

@ -33,7 +33,7 @@ from privacy.bolton.losses import StrongConvexMixin
class TestModel(Model): # pylint: disable=abstract-method
"""Bolton episilon-delta model.
"""BoltOn episilon-delta model.
Uses 4 key steps to achieve privacy guarantees:
1. Adds noise to weights after training (output perturbation).
@ -66,7 +66,7 @@ class TestModel(Model): # pylint: disable=abstract-method
class TestLoss(losses.Loss, StrongConvexMixin):
"""Test loss function for testing Bolton model."""
"""Test loss function for testing BoltOn model."""
def __init__(self, reg_lambda, c_arg, radius_constant, name='test'):
super(TestLoss, self).__init__(name=name)
@ -142,7 +142,7 @@ class TestLoss(losses.Loss, StrongConvexMixin):
class TestOptimizer(OptimizerV2):
"""Optimizer used for testing the Bolton optimizer."""
"""Optimizer used for testing the BoltOn optimizer."""
def __init__(self):
super(TestOptimizer, self).__init__('test')
@ -185,7 +185,7 @@ class TestOptimizer(OptimizerV2):
class BoltonOptimizerTest(keras_parameterized.TestCase):
"""Bolton Optimizer tests."""
"""BoltOn Optimizer tests."""
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': 'getattr',
@ -201,19 +201,19 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
])
def test_fn(self, fn, args, result, test_attr):
"""test that a fn of Bolton optimizer is working as expected.
"""test that a fn of BoltOn optimizer is working as expected.
Args:
fn: method of Optimizer to test
args: args to optimizer fn
result: the expected result
test_attr: None if the fn returns the test result. Otherwise, this is
the attribute of Bolton to check against result with.
the attribute of BoltOn to check against result with.
"""
tf.random.set_seed(1)
loss = TestLoss(1, 1, 1)
bolton = opt.Bolton(TestOptimizer(), loss)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(1)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
@ -260,7 +260,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
'result': [[1]]},
])
def test_project(self, r, shape, n_out, init_value, result):
"""test that a fn of Bolton optimizer is working as expected.
"""test that a fn of BoltOn optimizer is working as expected.
Args:
r: Radius value for StrongConvex loss function.
@ -273,7 +273,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
@tf.function
def project_fn(r):
loss = TestLoss(1, 1, r)
bolton = opt.Bolton(TestOptimizer(), loss)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(n_out, shape, init_value)
model.compile(bolton, loss)
model.layers[0].kernel = \
@ -308,7 +308,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
@tf.function
def test_run():
loss = TestLoss(1, 1, 1)
bolton = opt.Bolton(TestOptimizer(), loss)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(1, (1,), 1)
model.compile(bolton, loss)
model.layers[0].kernel = \
@ -343,7 +343,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
@tf.function
def test_run(noise, epsilon):
loss = TestLoss(1, 1, 1)
bolton = opt.Bolton(TestOptimizer(), loss)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(1, (1,), 1)
model.compile(bolton, loss)
model.layers[0].kernel = \
@ -371,7 +371,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
@tf.function
def test_run(fn, args):
loss = TestLoss(1, 1, 1)
bolton = opt.Bolton(TestOptimizer(), loss)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(1, (1,), 1)
model.compile(bolton, loss)
model.layers[0].kernel = \
@ -415,7 +415,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
"""Tests rerouted function.
Tests that a method of the internal optimizer is correctly routed from
the Bolton instance to the internal optimizer instance (TestOptimizer,
the BoltOn instance to the internal optimizer instance (TestOptimizer,
here).
Args:
@ -424,7 +424,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
"""
loss = TestLoss(1, 1, 1)
optimizer = TestOptimizer()
bolton = opt.Bolton(optimizer, loss)
bolton = opt.BoltOn(optimizer, loss)
model = TestModel(3)
model.compile(optimizer, loss)
model.layers[0].kernel = \
@ -465,7 +465,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
@tf.function
def test_run(fn, args):
loss = TestLoss(1, 1, 1)
bolton = opt.Bolton(TestOptimizer(), loss)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(1, (1,), 1)
model.compile(bolton, loss)
model.layers[0].kernel = \
@ -502,7 +502,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
"""
loss = TestLoss(1, 1, 1)
internal_optimizer = TestOptimizer()
optimizer = opt.Bolton(internal_optimizer, loss)
optimizer = opt.BoltOn(internal_optimizer, loss)
self.assertEqual(getattr(optimizer, attr),
getattr(internal_optimizer, attr))
@ -521,7 +521,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
"""
loss = TestLoss(1, 1, 1)
internal_optimizer = TestOptimizer()
optimizer = opt.Bolton(internal_optimizer, loss)
optimizer = opt.BoltOn(internal_optimizer, loss)
with self.assertRaises(AttributeError):
getattr(optimizer, attr)

View file

@ -18,7 +18,7 @@ from __future__ import print_function
import tensorflow as tf # pylint: disable=wrong-import-position
from privacy.bolton import losses # pylint: disable=wrong-import-position
from privacy.bolton import models # pylint: disable=wrong-import-position
from privacy.bolton.optimizers import Bolton # pylint: disable=wrong-import-position
from privacy.bolton.optimizers import BoltOn # pylint: disable=wrong-import-position
# -------
# First, we will create a binary classification dataset with a single output
# dimension. The samples for each label are repeated data points at different
@ -39,12 +39,12 @@ generator = tf.data.Dataset.from_tensor_slices((x, y))
generator = generator.batch(10)
generator = generator.shuffle(10)
# -------
# First, we will explore using the pre - built BoltonModel, which is a thin
# First, we will explore using the pre - built BoltOnModel, which is a thin
# wrapper around a Keras Model using a single - layer neural network.
# It automatically uses the Bolton Optimizer which encompasses all the logic
# required for the Bolton Differential Privacy method.
# It automatically uses the BoltOn Optimizer which encompasses all the logic
# required for the BoltOn Differential Privacy method.
# -------
bolt = models.BoltonModel(n_outputs) # tell the model how many outputs we have.
bolt = models.BoltOnModel(n_outputs) # tell the model how many outputs we have.
# -------
# Now, we will pick our optimizer and Strongly Convex Loss function. The loss
# must extend from StrongConvexMixin and implement the associated methods.Some
@ -60,7 +60,7 @@ loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
# to be 1; these are all tunable and their impact can be read in losses.
# StrongConvexBinaryCrossentropy.We then compile the model with the chosen
# optimizer and loss, which will automatically wrap the chosen optimizer with
# the Bolton Optimizer, ensuring the required components function as required
# the BoltOn Optimizer, ensuring the required components function as required
# for privacy guarantees.
# -------
bolt.compile(optimizer, loss)
@ -77,7 +77,7 @@ bolt.compile(optimizer, loss)
# 2. noise_distribution, a valid string indicating the distriution to use (must
# be implemented)
#
# The BoltonModel offers a helper method,.calculate_class_weight to aid in
# The BoltOnModel offers a helper method,.calculate_class_weight to aid in
# class_weight calculation.
# required parameters
# -------
@ -132,7 +132,7 @@ bolt.fit(generator,
noise_distribution=noise_distribution,
verbose=0)
# -------
# You don't have to use the bolton model to use the Bolton method.
# You don't have to use the bolton model to use the BoltOn method.
# There are only a few requirements:
# 1. make sure any requirements from the loss are implemented in the model.
# 2. instantiate the optimizer and use it as a context around the fit operation.
@ -140,7 +140,7 @@ bolt.fit(generator,
# -------------------- Part 2, using the Optimizer
# -------
# Here, we create our own model and setup the Bolton optimizer.
# Here, we create our own model and setup the BoltOn optimizer.
# -------
@ -157,7 +157,7 @@ class TestModel(tf.keras.Model): # pylint: disable=abstract-method
optimizer = tf.optimizers.SGD()
loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
optimizer = Bolton(optimizer, loss)
optimizer = BoltOn(optimizer, loss)
# -------
# Now, we instantiate our model and check for 1. Since our loss requires L2
# regularization over the kernel, we will pass it to the model.
@ -166,7 +166,7 @@ n_outputs = 1 # parameter for model and optimizer context.
test_model = TestModel(loss.kernel_regularizer(), n_outputs)
test_model.compile(optimizer, loss)
# -------
# We comply with 2., and use the Bolton Optimizer as a context around the fit
# We comply with 2., and use the BoltOn Optimizer as a context around the fit
# method.
# -------
# parameters for context