From 2065f2b16a42c5c604afb12e3502ac3729418883 Mon Sep 17 00:00:00 2001 From: Christopher Choquette Choo Date: Tue, 30 Jul 2019 15:12:22 -0400 Subject: [PATCH] Code style and documentation changes. --- privacy/__init__.py | 4 +-- privacy/bolton/README.md | 19 +++++++------- privacy/bolton/__init__.py | 6 ++--- privacy/bolton/losses.py | 2 +- privacy/bolton/losses_test.py | 2 +- privacy/bolton/models.py | 43 ++++++++++++++++--------------- privacy/bolton/models_test.py | 40 ++++++++++++++-------------- privacy/bolton/optimizers.py | 36 +++++++++++++------------- privacy/bolton/optimizers_test.py | 34 ++++++++++++------------ tutorials/bolton_tutorial.py | 22 ++++++++-------- 10 files changed, 105 insertions(+), 103 deletions(-) diff --git a/privacy/__init__.py b/privacy/__init__.py index e494c62..94add1e 100644 --- a/privacy/__init__.py +++ b/privacy/__init__.py @@ -42,8 +42,8 @@ else: from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer - from privacy.bolton.models import BoltonModel - from privacy.bolton.optimizers import Bolton + from privacy.bolton.models import BoltOnModel + from privacy.bolton.optimizers import BoltOn from privacy.bolton.losses import StrongConvexMixin from privacy.bolton.losses import StrongConvexBinaryCrossentropy from privacy.bolton.losses import StrongConvexHuber diff --git a/privacy/bolton/README.md b/privacy/bolton/README.md index 4aef36f..54eb91a 100644 --- a/privacy/bolton/README.md +++ b/privacy/bolton/README.md @@ -1,25 +1,26 @@ -# Bolton Subpackage +# BoltOn Subpackage -This package contains source code for the Bolton method. This method is a subset -of methods used in the ensuring privacy in machine learning that leverages -additional assumptions to provide a new way of approaching the privacy +This package contains source code for the BoltOn method, a particular +differential-privacy (DP) technique that uses output perturbations and leverages +additional assumptions to provide a new way of approaching the privacy guarantees. -## Bolton Description +## BoltOn Description This method uses 4 key steps to achieve privacy guarantees: 1. Adds noise to weights after training (output perturbation). - 2. Projects weights to R after each batch + 2. Projects weights to R, the radius of the hypothesis space, + after each batch. This value is configurable by the user. 3. Limits learning rate 4. Use a strongly convex loss function (see compile) For more details on the strong convexity requirements, see: Bolt-on Differential Privacy for Scalable Stochastic Gradient -Descent-based Analytics by Xi Wu et al. +Descent-based Analytics by Xi Wu et al. at https://arxiv.org/pdf/1606.04722.pdf -## Why Bolton? +## Why BoltOn? -The major difference for the Bolton method is that it injects noise post model +The major difference for the BoltOn method is that it injects noise post model convergence, rather than noising gradients or weights during training. This approach requires some additional constraints listed in the Description. Should the use-case and model satisfy these constraints, this is another diff --git a/privacy/bolton/__init__.py b/privacy/bolton/__init__.py index 5dc3940..bc7a027 100644 --- a/privacy/bolton/__init__.py +++ b/privacy/bolton/__init__.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Bolton Method for privacy.""" +"""BoltOn Method for privacy.""" import sys from distutils.version import LooseVersion import tensorflow as tf @@ -23,7 +23,7 @@ if LooseVersion(tf.__version__) < LooseVersion("2.0.0"): if hasattr(sys, "skip_tf_privacy_import"): # Useful for standalone scripts. pass else: - from privacy.bolton.models import BoltonModel # pylint: disable=g-import-not-at-top - from privacy.bolton.optimizers import Bolton # pylint: disable=g-import-not-at-top + from privacy.bolton.models import BoltOnModel # pylint: disable=g-import-not-at-top + from privacy.bolton.optimizers import BoltOn # pylint: disable=g-import-not-at-top from privacy.bolton.losses import StrongConvexHuber # pylint: disable=g-import-not-at-top from privacy.bolton.losses import StrongConvexBinaryCrossentropy # pylint: disable=g-import-not-at-top diff --git a/privacy/bolton/losses.py b/privacy/bolton/losses.py index 880b8c5..c742326 100644 --- a/privacy/bolton/losses.py +++ b/privacy/bolton/losses.py @@ -29,7 +29,7 @@ class StrongConvexMixin: # pylint: disable=old-style-class """Strong Convex Mixin base class. Strong Convex Mixin base class for any loss function that will be used with - Bolton model. Subclasses must be strongly convex and implement the + BoltOn model. Subclasses must be strongly convex and implement the associated constants. They must also conform to the requirements of tf losses (see super class). diff --git a/privacy/bolton/losses_test.py b/privacy/bolton/losses_test.py index 6c60c35..ff8137c 100644 --- a/privacy/bolton/losses_test.py +++ b/privacy/bolton/losses_test.py @@ -372,7 +372,7 @@ class HuberTests(keras_parameterized.TestCase): Args: logits: unscaled output of model y_true: label - delta: + delta: delta value for StrongConvexHuber loss. result: correct loss calculation value """ logits = tf.Variable(logits, False, dtype=tf.float32) diff --git a/privacy/bolton/models.py b/privacy/bolton/models.py index 10c19b7..ad0f59c 100644 --- a/privacy/bolton/models.py +++ b/privacy/bolton/models.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Bolton model for bolton method of differentially private ML.""" +"""BoltOn model for bolton method of differentially private ML.""" from __future__ import absolute_import from __future__ import division @@ -21,11 +21,11 @@ from tensorflow.python.framework import ops as _ops from tensorflow.python.keras import optimizers from tensorflow.python.keras.models import Model from privacy.bolton.losses import StrongConvexMixin -from privacy.bolton.optimizers import Bolton +from privacy.bolton.optimizers import BoltOn -class BoltonModel(Model): # pylint: disable=abstract-method - """Bolton episilon-delta differential privacy model. +class BoltOnModel(Model): # pylint: disable=abstract-method + """BoltOn episilon-delta differential privacy model. The privacy guarantees are dependent on the noise that is sampled. Please see the paper linked below for more details. @@ -52,7 +52,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method seed: random seed to use dtype: data type to use for tensors """ - super(BoltonModel, self).__init__(name='bolton', dynamic=False) + super(BoltOnModel, self).__init__(name='bolton', dynamic=False) if n_outputs <= 0: raise ValueError('n_outputs = {0} is not valid. Must be > 0.'.format( n_outputs @@ -69,6 +69,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method inputs: inputs to neural network Returns: + Output logits for the given inputs. """ return self.output_layer(inputs) @@ -78,11 +79,11 @@ class BoltonModel(Model): # pylint: disable=abstract-method loss, kernel_initializer=tf.initializers.GlorotUniform, **kwargs): # pylint: disable=arguments-differ - """See super class. Default optimizer used in Bolton method is SGD. + """See super class. Default optimizer used in BoltOn method is SGD. Args: optimizer: The optimizer to use. This will be automatically wrapped - with the Bolton Optimizer. + with the BoltOn Optimizer. loss: The loss function to use. Must be a StrongConvex loss (extend the StrongConvexMixin). kernel_initializer: The kernel initializer to use for the single layer. @@ -99,11 +100,11 @@ class BoltonModel(Model): # pylint: disable=abstract-method kernel_initializer=kernel_initializer(), ) self._layers_instantiated = True - if not isinstance(optimizer, Bolton): + if not isinstance(optimizer, BoltOn): optimizer = optimizers.get(optimizer) - optimizer = Bolton(optimizer, loss) + optimizer = BoltOn(optimizer, loss) - super(BoltonModel, self).compile(optimizer, loss=loss, **kwargs) + super(BoltOnModel, self).compile(optimizer, loss=loss, **kwargs) def fit(self, x=None, @@ -115,7 +116,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method noise_distribution='laplace', steps_per_epoch=None, **kwargs): # pylint: disable=arguments-differ - """Reroutes to super fit with Bolton delta-epsilon privacy requirements. + """Reroutes to super fit with BoltOn delta-epsilon privacy requirements. Note, inputs must be normalized s.t. ||x|| < 1. Requirements are as follows: @@ -126,9 +127,9 @@ class BoltonModel(Model): # pylint: disable=abstract-method See super implementation for more details. Args: - x: - y: - batch_size: + x: Inputs to fit on, see super. + y: Labels to fit on, see super. + batch_size: The batch size to use for training, see super. class_weight: the class weights to be used. Can be a scalar or 1D tensor whose dim == n_classes. n_samples: the number of individual samples in x. @@ -139,7 +140,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method **kwargs: kwargs to keras Model.fit. See super. Returns: - output + Output from super fit method. """ if class_weight is None: class_weight_ = self.calculate_class_weights(class_weight) @@ -170,7 +171,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method class_weight_, data_size, batch_size_) as _: - out = super(BoltonModel, self).fit(x=x, + out = super(BoltOnModel, self).fit(x=x, y=y, batch_size=batch_size, class_weight=class_weight, @@ -192,18 +193,18 @@ class BoltonModel(Model): # pylint: disable=abstract-method is a generator. See super method and fit for more details. Args: - generator: + generator: Inputs generator following Tensorflow guidelines, see super. class_weight: the class weights to be used. Can be a scalar or 1D tensor whose dim == n_classes. noise_distribution: the distribution to get noise from. epsilon: privacy parameter, which trades off utility and privacy. See - Bolton paper for more description. + BoltOn paper for more description. n_samples: number of individual samples in x - steps_per_epoch: + steps_per_epoch: Number of steps per training epoch, see super. **kwargs: **kwargs Returns: - output + Output from super fit_generator method. """ if class_weight is None: class_weight = self.calculate_class_weights(class_weight) @@ -224,7 +225,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method class_weight, data_size, batch_size) as _: - out = super(BoltonModel, self).fit_generator( + out = super(BoltOnModel, self).fit_generator( generator, class_weight=class_weight, steps_per_epoch=steps_per_epoch, diff --git a/privacy/bolton/models_test.py b/privacy/bolton/models_test.py index f5365fe..b252312 100644 --- a/privacy/bolton/models_test.py +++ b/privacy/bolton/models_test.py @@ -26,11 +26,11 @@ from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2 from tensorflow.python.keras.regularizers import L1L2 from privacy.bolton import models from privacy.bolton.losses import StrongConvexMixin -from privacy.bolton.optimizers import Bolton +from privacy.bolton.optimizers import BoltOn class TestLoss(losses.Loss, StrongConvexMixin): - """Test loss function for testing Bolton model.""" + """Test loss function for testing BoltOn model.""" def __init__(self, reg_lambda, c_arg, radius_constant, name='test'): super(TestLoss, self).__init__(name=name) @@ -105,7 +105,7 @@ class TestLoss(losses.Loss, StrongConvexMixin): class TestOptimizer(OptimizerV2): - """Test optimizer used for testing Bolton model.""" + """Test optimizer used for testing BoltOn model.""" def __init__(self): super(TestOptimizer, self).__init__('test') @@ -138,14 +138,14 @@ class InitTests(keras_parameterized.TestCase): }, ]) def test_init_params(self, n_outputs): - """Test initialization of BoltonModel. + """Test initialization of BoltOnModel. Args: n_outputs: number of output neurons """ # test valid domains for each variable - clf = models.BoltonModel(n_outputs) - self.assertIsInstance(clf, models.BoltonModel) + clf = models.BoltOnModel(n_outputs) + self.assertIsInstance(clf, models.BoltOnModel) @parameterized.named_parameters([ {'testcase_name': 'invalid n_outputs', @@ -153,14 +153,14 @@ class InitTests(keras_parameterized.TestCase): }, ]) def test_bad_init_params(self, n_outputs): - """test bad initializations of BoltonModel that should raise errors. + """test bad initializations of BoltOnModel that should raise errors. Args: n_outputs: number of output neurons """ # test invalid domains for each variable, especially noise with self.assertRaises(ValueError): - models.BoltonModel(n_outputs) + models.BoltOnModel(n_outputs) @parameterized.named_parameters([ {'testcase_name': 'string compile', @@ -175,7 +175,7 @@ class InitTests(keras_parameterized.TestCase): }, ]) def test_compile(self, n_outputs, loss, optimizer): - """Test compilation of BoltonModel. + """Test compilation of BoltOnModel. Args: n_outputs: number of output neurons @@ -184,7 +184,7 @@ class InitTests(keras_parameterized.TestCase): """ # test compilation of valid tf.optimizer and tf.loss with self.cached_session(): - clf = models.BoltonModel(n_outputs) + clf = models.BoltOnModel(n_outputs) clf.compile(optimizer, loss) self.assertEqual(clf.loss, loss) @@ -201,7 +201,7 @@ class InitTests(keras_parameterized.TestCase): } ]) def test_bad_compile(self, n_outputs, loss, optimizer): - """test bad compilations of BoltonModel that should raise errors. + """test bad compilations of BoltOnModel that should raise errors. Args: n_outputs: number of output neurons @@ -211,7 +211,7 @@ class InitTests(keras_parameterized.TestCase): # test compilaton of invalid tf.optimizer and non instantiated loss. with self.cached_session(): with self.assertRaises((ValueError, AttributeError)): - clf = models.BoltonModel(n_outputs) + clf = models.BoltOnModel(n_outputs) clf.compile(optimizer, loss) @@ -276,9 +276,9 @@ def _do_fit(n_samples, distribution: distribution to get noise from. Returns: - BoltonModel instsance + BoltOnModel instsance """ - clf = models.BoltonModel(n_outputs) + clf = models.BoltOnModel(n_outputs) clf.compile(optimizer, loss) if generator: x = _cat_dataset( @@ -328,14 +328,14 @@ class FitTests(keras_parameterized.TestCase): }, ]) def test_fit(self, generator, reset_n_samples): - """Tests fitting of BoltonModel. + """Tests fitting of BoltOnModel. Args: generator: True for generator test, False for iterator test. reset_n_samples: True to reset the n_samples to None, False does nothing """ loss = TestLoss(1, 1, 1) - optimizer = Bolton(TestOptimizer(), loss) + optimizer = BoltOn(TestOptimizer(), loss) n_classes = 2 input_dim = 5 epsilon = 1 @@ -360,7 +360,7 @@ class FitTests(keras_parameterized.TestCase): }, ]) def test_fit_gen(self, generator): - """Tests the fit_generator method of BoltonModel. + """Tests the fit_generator method of BoltOnModel. Args: generator: True to test with a generator dataset @@ -371,7 +371,7 @@ class FitTests(keras_parameterized.TestCase): input_dim = 5 batch_size = 1 n_samples = 10 - clf = models.BoltonModel(n_classes) + clf = models.BoltOnModel(n_classes) clf.compile(optimizer, loss) x = _cat_dataset( n_samples, @@ -456,7 +456,7 @@ class FitTests(keras_parameterized.TestCase): num_classes: number of outputs neurons result: expected result """ - clf = models.BoltonModel(1, 1) + clf = models.BoltOnModel(1, 1) expected = clf.calculate_class_weights(class_weights, class_counts, num_classes) @@ -523,7 +523,7 @@ class FitTests(keras_parameterized.TestCase): num_classes: number of outputs neurons. err_msg: The expected error message. """ - clf = models.BoltonModel(1, 1) + clf = models.BoltOnModel(1, 1) with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method clf.calculate_class_weights(class_weights, class_counts, diff --git a/privacy/bolton/optimizers.py b/privacy/bolton/optimizers.py index a18c636..d647bbb 100644 --- a/privacy/bolton/optimizers.py +++ b/privacy/bolton/optimizers.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Bolton Optimizer for bolton method.""" +"""BoltOn Optimizer for bolton method.""" from __future__ import absolute_import from __future__ import division @@ -45,15 +45,15 @@ class GammaBetaDecreasingStep( Returns: decayed learning rate to minimum of 1/beta and 1/(gamma * step) as per - the Bolton privacy requirements. + the BoltOn privacy requirements. """ if not self.is_init: raise AttributeError('Please initialize the {0} Learning Rate Scheduler.' 'This is performed automatically by using the ' '{1} as a context manager, ' 'as desired'.format(self.__class__.__name__, - Bolton.__class__.__name__ - ) + BoltOn.__class__.__name__ + ) ) dtype = self.beta.dtype one = tf.constant(1, dtype) @@ -86,22 +86,22 @@ class GammaBetaDecreasingStep( self.gamma = None -class Bolton(optimizer_v2.OptimizerV2): - """Wrap another tf optimizer with Bolton privacy protocol. +class BoltOn(optimizer_v2.OptimizerV2): + """Wrap another tf optimizer with BoltOn privacy protocol. - Bolton optimizer wraps another tf optimizer to be used + BoltOn optimizer wraps another tf optimizer to be used as the visible optimizer to the tf model. No matter the optimizer - passed, "Bolton" enables the bolton model to control the learning rate + passed, "BoltOn" enables the bolton model to control the learning rate based on the strongly convex loss. - To use the Bolton method, you must: + To use the BoltOn method, you must: 1. instantiate it with an instantiated tf optimizer and StrongConvexLoss. 2. use it as a context manager around your .fit method internals. This can be accomplished by the following: optimizer = tf.optimizers.SGD() loss = privacy.bolton.losses.StrongConvexBinaryCrossentropy() - bolton = Bolton(optimizer, loss) + bolton = BoltOn(optimizer, loss) with bolton(*args) as _: model.fit() The args required for the context manager can be found in the __call__ @@ -142,7 +142,7 @@ class Bolton(optimizer_v2.OptimizerV2): '_is_init' ] self._internal_optimizer = optimizer - self.learning_rate = GammaBetaDecreasingStep() # use the Bolton Learning + self.learning_rate = GammaBetaDecreasingStep() # use the BoltOn Learning # rate scheduler, as required for privacy guarantees. This will still need # to get values from the loss function near the time that .fit is called # on the model (when this optimizer will be called as a context manager) @@ -162,7 +162,7 @@ class Bolton(optimizer_v2.OptimizerV2): False to check if weights > R-ball and only normalize then. Raises: - Exception: + Exception: If not called from inside this optimizer context. """ if not self._is_init: raise Exception('This method must be called from within the optimizer\'s ' @@ -190,7 +190,7 @@ class Bolton(optimizer_v2.OptimizerV2): Noise in shape of layer's weights to be added to the weights. Raises: - Exception: + Exception: If not called from inside this optimizer's context. """ if not self._is_init: raise Exception('This method must be called from within the optimizer\'s ' @@ -234,10 +234,10 @@ class Bolton(optimizer_v2.OptimizerV2): from the _internal_optimizer instance. Args: - name: + name: Name of attribute to get from this or aggregate optimizer. Returns: - attribute from Bolton if specified to come from self, else + attribute from BoltOn if specified to come from self, else from _internal_optimizer. """ if name == '_private_attributes' or name in self._private_attributes: @@ -336,7 +336,7 @@ class Bolton(optimizer_v2.OptimizerV2): batch_size: batch size used. Returns: - self + self, to be used by the __enter__ method for context. """ if epsilon <= 0: raise ValueError('Detected epsilon: {0}. ' @@ -361,7 +361,7 @@ class Bolton(optimizer_v2.OptimizerV2): Used to: 1.reset the model and fit parameters passed to the optimizer - to enable the Bolton Privacy guarantees. These are reset to ensure + to enable the BoltOn Privacy guarantees. These are reset to ensure that any future calls to fit with the same instance of the optimizer will properly error out. @@ -369,7 +369,7 @@ class Bolton(optimizer_v2.OptimizerV2): adding noise to the weights. Args: - *args: *args + *args: encompasses the type, value, and traceback values which are unused. """ self.project_weights_to_r(True) for layer in self.layers: diff --git a/privacy/bolton/optimizers_test.py b/privacy/bolton/optimizers_test.py index 4b08d66..abfffdd 100644 --- a/privacy/bolton/optimizers_test.py +++ b/privacy/bolton/optimizers_test.py @@ -33,7 +33,7 @@ from privacy.bolton.losses import StrongConvexMixin class TestModel(Model): # pylint: disable=abstract-method - """Bolton episilon-delta model. + """BoltOn episilon-delta model. Uses 4 key steps to achieve privacy guarantees: 1. Adds noise to weights after training (output perturbation). @@ -66,7 +66,7 @@ class TestModel(Model): # pylint: disable=abstract-method class TestLoss(losses.Loss, StrongConvexMixin): - """Test loss function for testing Bolton model.""" + """Test loss function for testing BoltOn model.""" def __init__(self, reg_lambda, c_arg, radius_constant, name='test'): super(TestLoss, self).__init__(name=name) @@ -142,7 +142,7 @@ class TestLoss(losses.Loss, StrongConvexMixin): class TestOptimizer(OptimizerV2): - """Optimizer used for testing the Bolton optimizer.""" + """Optimizer used for testing the BoltOn optimizer.""" def __init__(self): super(TestOptimizer, self).__init__('test') @@ -185,7 +185,7 @@ class TestOptimizer(OptimizerV2): class BoltonOptimizerTest(keras_parameterized.TestCase): - """Bolton Optimizer tests.""" + """BoltOn Optimizer tests.""" @test_util.run_all_in_graph_and_eager_modes @parameterized.named_parameters([ {'testcase_name': 'getattr', @@ -201,19 +201,19 @@ class BoltonOptimizerTest(keras_parameterized.TestCase): ]) def test_fn(self, fn, args, result, test_attr): - """test that a fn of Bolton optimizer is working as expected. + """test that a fn of BoltOn optimizer is working as expected. Args: fn: method of Optimizer to test args: args to optimizer fn result: the expected result test_attr: None if the fn returns the test result. Otherwise, this is - the attribute of Bolton to check against result with. + the attribute of BoltOn to check against result with. """ tf.random.set_seed(1) loss = TestLoss(1, 1, 1) - bolton = opt.Bolton(TestOptimizer(), loss) + bolton = opt.BoltOn(TestOptimizer(), loss) model = TestModel(1) model.layers[0].kernel = \ model.layers[0].kernel_initializer((model.layer_input_shape[0], @@ -260,7 +260,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase): 'result': [[1]]}, ]) def test_project(self, r, shape, n_out, init_value, result): - """test that a fn of Bolton optimizer is working as expected. + """test that a fn of BoltOn optimizer is working as expected. Args: r: Radius value for StrongConvex loss function. @@ -273,7 +273,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase): @tf.function def project_fn(r): loss = TestLoss(1, 1, r) - bolton = opt.Bolton(TestOptimizer(), loss) + bolton = opt.BoltOn(TestOptimizer(), loss) model = TestModel(n_out, shape, init_value) model.compile(bolton, loss) model.layers[0].kernel = \ @@ -308,7 +308,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase): @tf.function def test_run(): loss = TestLoss(1, 1, 1) - bolton = opt.Bolton(TestOptimizer(), loss) + bolton = opt.BoltOn(TestOptimizer(), loss) model = TestModel(1, (1,), 1) model.compile(bolton, loss) model.layers[0].kernel = \ @@ -343,7 +343,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase): @tf.function def test_run(noise, epsilon): loss = TestLoss(1, 1, 1) - bolton = opt.Bolton(TestOptimizer(), loss) + bolton = opt.BoltOn(TestOptimizer(), loss) model = TestModel(1, (1,), 1) model.compile(bolton, loss) model.layers[0].kernel = \ @@ -371,7 +371,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase): @tf.function def test_run(fn, args): loss = TestLoss(1, 1, 1) - bolton = opt.Bolton(TestOptimizer(), loss) + bolton = opt.BoltOn(TestOptimizer(), loss) model = TestModel(1, (1,), 1) model.compile(bolton, loss) model.layers[0].kernel = \ @@ -415,7 +415,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase): """Tests rerouted function. Tests that a method of the internal optimizer is correctly routed from - the Bolton instance to the internal optimizer instance (TestOptimizer, + the BoltOn instance to the internal optimizer instance (TestOptimizer, here). Args: @@ -424,7 +424,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase): """ loss = TestLoss(1, 1, 1) optimizer = TestOptimizer() - bolton = opt.Bolton(optimizer, loss) + bolton = opt.BoltOn(optimizer, loss) model = TestModel(3) model.compile(optimizer, loss) model.layers[0].kernel = \ @@ -465,7 +465,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase): @tf.function def test_run(fn, args): loss = TestLoss(1, 1, 1) - bolton = opt.Bolton(TestOptimizer(), loss) + bolton = opt.BoltOn(TestOptimizer(), loss) model = TestModel(1, (1,), 1) model.compile(bolton, loss) model.layers[0].kernel = \ @@ -502,7 +502,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase): """ loss = TestLoss(1, 1, 1) internal_optimizer = TestOptimizer() - optimizer = opt.Bolton(internal_optimizer, loss) + optimizer = opt.BoltOn(internal_optimizer, loss) self.assertEqual(getattr(optimizer, attr), getattr(internal_optimizer, attr)) @@ -521,7 +521,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase): """ loss = TestLoss(1, 1, 1) internal_optimizer = TestOptimizer() - optimizer = opt.Bolton(internal_optimizer, loss) + optimizer = opt.BoltOn(internal_optimizer, loss) with self.assertRaises(AttributeError): getattr(optimizer, attr) diff --git a/tutorials/bolton_tutorial.py b/tutorials/bolton_tutorial.py index ae9707e..c56f9bf 100644 --- a/tutorials/bolton_tutorial.py +++ b/tutorials/bolton_tutorial.py @@ -18,7 +18,7 @@ from __future__ import print_function import tensorflow as tf # pylint: disable=wrong-import-position from privacy.bolton import losses # pylint: disable=wrong-import-position from privacy.bolton import models # pylint: disable=wrong-import-position -from privacy.bolton.optimizers import Bolton # pylint: disable=wrong-import-position +from privacy.bolton.optimizers import BoltOn # pylint: disable=wrong-import-position # ------- # First, we will create a binary classification dataset with a single output # dimension. The samples for each label are repeated data points at different @@ -39,12 +39,12 @@ generator = tf.data.Dataset.from_tensor_slices((x, y)) generator = generator.batch(10) generator = generator.shuffle(10) # ------- -# First, we will explore using the pre - built BoltonModel, which is a thin +# First, we will explore using the pre - built BoltOnModel, which is a thin # wrapper around a Keras Model using a single - layer neural network. -# It automatically uses the Bolton Optimizer which encompasses all the logic -# required for the Bolton Differential Privacy method. +# It automatically uses the BoltOn Optimizer which encompasses all the logic +# required for the BoltOn Differential Privacy method. # ------- -bolt = models.BoltonModel(n_outputs) # tell the model how many outputs we have. +bolt = models.BoltOnModel(n_outputs) # tell the model how many outputs we have. # ------- # Now, we will pick our optimizer and Strongly Convex Loss function. The loss # must extend from StrongConvexMixin and implement the associated methods.Some @@ -60,7 +60,7 @@ loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant) # to be 1; these are all tunable and their impact can be read in losses. # StrongConvexBinaryCrossentropy.We then compile the model with the chosen # optimizer and loss, which will automatically wrap the chosen optimizer with -# the Bolton Optimizer, ensuring the required components function as required +# the BoltOn Optimizer, ensuring the required components function as required # for privacy guarantees. # ------- bolt.compile(optimizer, loss) @@ -77,7 +77,7 @@ bolt.compile(optimizer, loss) # 2. noise_distribution, a valid string indicating the distriution to use (must # be implemented) # -# The BoltonModel offers a helper method,.calculate_class_weight to aid in +# The BoltOnModel offers a helper method,.calculate_class_weight to aid in # class_weight calculation. # required parameters # ------- @@ -132,7 +132,7 @@ bolt.fit(generator, noise_distribution=noise_distribution, verbose=0) # ------- -# You don't have to use the bolton model to use the Bolton method. +# You don't have to use the bolton model to use the BoltOn method. # There are only a few requirements: # 1. make sure any requirements from the loss are implemented in the model. # 2. instantiate the optimizer and use it as a context around the fit operation. @@ -140,7 +140,7 @@ bolt.fit(generator, # -------------------- Part 2, using the Optimizer # ------- -# Here, we create our own model and setup the Bolton optimizer. +# Here, we create our own model and setup the BoltOn optimizer. # ------- @@ -157,7 +157,7 @@ class TestModel(tf.keras.Model): # pylint: disable=abstract-method optimizer = tf.optimizers.SGD() loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant) -optimizer = Bolton(optimizer, loss) +optimizer = BoltOn(optimizer, loss) # ------- # Now, we instantiate our model and check for 1. Since our loss requires L2 # regularization over the kernel, we will pass it to the model. @@ -166,7 +166,7 @@ n_outputs = 1 # parameter for model and optimizer context. test_model = TestModel(loss.kernel_regularizer(), n_outputs) test_model.compile(optimizer, loss) # ------- -# We comply with 2., and use the Bolton Optimizer as a context around the fit +# We comply with 2., and use the BoltOn Optimizer as a context around the fit # method. # ------- # parameters for context