forked from 626_privacy/tensorflow_privacy
Code style and documentation changes.
This commit is contained in:
parent
fb12ee047f
commit
2065f2b16a
10 changed files with 105 additions and 103 deletions
|
@ -42,8 +42,8 @@ else:
|
|||
from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
|
||||
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
|
||||
|
||||
from privacy.bolton.models import BoltonModel
|
||||
from privacy.bolton.optimizers import Bolton
|
||||
from privacy.bolton.models import BoltOnModel
|
||||
from privacy.bolton.optimizers import BoltOn
|
||||
from privacy.bolton.losses import StrongConvexMixin
|
||||
from privacy.bolton.losses import StrongConvexBinaryCrossentropy
|
||||
from privacy.bolton.losses import StrongConvexHuber
|
||||
|
|
|
@ -1,25 +1,26 @@
|
|||
# Bolton Subpackage
|
||||
# BoltOn Subpackage
|
||||
|
||||
This package contains source code for the Bolton method. This method is a subset
|
||||
of methods used in the ensuring privacy in machine learning that leverages
|
||||
This package contains source code for the BoltOn method, a particular
|
||||
differential-privacy (DP) technique that uses output perturbations and leverages
|
||||
additional assumptions to provide a new way of approaching the privacy
|
||||
guarantees.
|
||||
|
||||
## Bolton Description
|
||||
## BoltOn Description
|
||||
|
||||
This method uses 4 key steps to achieve privacy guarantees:
|
||||
1. Adds noise to weights after training (output perturbation).
|
||||
2. Projects weights to R after each batch
|
||||
2. Projects weights to R, the radius of the hypothesis space,
|
||||
after each batch. This value is configurable by the user.
|
||||
3. Limits learning rate
|
||||
4. Use a strongly convex loss function (see compile)
|
||||
|
||||
For more details on the strong convexity requirements, see:
|
||||
Bolt-on Differential Privacy for Scalable Stochastic Gradient
|
||||
Descent-based Analytics by Xi Wu et al.
|
||||
Descent-based Analytics by Xi Wu et al. at https://arxiv.org/pdf/1606.04722.pdf
|
||||
|
||||
## Why Bolton?
|
||||
## Why BoltOn?
|
||||
|
||||
The major difference for the Bolton method is that it injects noise post model
|
||||
The major difference for the BoltOn method is that it injects noise post model
|
||||
convergence, rather than noising gradients or weights during training. This
|
||||
approach requires some additional constraints listed in the Description.
|
||||
Should the use-case and model satisfy these constraints, this is another
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Bolton Method for privacy."""
|
||||
"""BoltOn Method for privacy."""
|
||||
import sys
|
||||
from distutils.version import LooseVersion
|
||||
import tensorflow as tf
|
||||
|
@ -23,7 +23,7 @@ if LooseVersion(tf.__version__) < LooseVersion("2.0.0"):
|
|||
if hasattr(sys, "skip_tf_privacy_import"): # Useful for standalone scripts.
|
||||
pass
|
||||
else:
|
||||
from privacy.bolton.models import BoltonModel # pylint: disable=g-import-not-at-top
|
||||
from privacy.bolton.optimizers import Bolton # pylint: disable=g-import-not-at-top
|
||||
from privacy.bolton.models import BoltOnModel # pylint: disable=g-import-not-at-top
|
||||
from privacy.bolton.optimizers import BoltOn # pylint: disable=g-import-not-at-top
|
||||
from privacy.bolton.losses import StrongConvexHuber # pylint: disable=g-import-not-at-top
|
||||
from privacy.bolton.losses import StrongConvexBinaryCrossentropy # pylint: disable=g-import-not-at-top
|
||||
|
|
|
@ -29,7 +29,7 @@ class StrongConvexMixin: # pylint: disable=old-style-class
|
|||
"""Strong Convex Mixin base class.
|
||||
|
||||
Strong Convex Mixin base class for any loss function that will be used with
|
||||
Bolton model. Subclasses must be strongly convex and implement the
|
||||
BoltOn model. Subclasses must be strongly convex and implement the
|
||||
associated constants. They must also conform to the requirements of tf losses
|
||||
(see super class).
|
||||
|
||||
|
|
|
@ -372,7 +372,7 @@ class HuberTests(keras_parameterized.TestCase):
|
|||
Args:
|
||||
logits: unscaled output of model
|
||||
y_true: label
|
||||
delta:
|
||||
delta: delta value for StrongConvexHuber loss.
|
||||
result: correct loss calculation value
|
||||
"""
|
||||
logits = tf.Variable(logits, False, dtype=tf.float32)
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Bolton model for bolton method of differentially private ML."""
|
||||
"""BoltOn model for bolton method of differentially private ML."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -21,11 +21,11 @@ from tensorflow.python.framework import ops as _ops
|
|||
from tensorflow.python.keras import optimizers
|
||||
from tensorflow.python.keras.models import Model
|
||||
from privacy.bolton.losses import StrongConvexMixin
|
||||
from privacy.bolton.optimizers import Bolton
|
||||
from privacy.bolton.optimizers import BoltOn
|
||||
|
||||
|
||||
class BoltonModel(Model): # pylint: disable=abstract-method
|
||||
"""Bolton episilon-delta differential privacy model.
|
||||
class BoltOnModel(Model): # pylint: disable=abstract-method
|
||||
"""BoltOn episilon-delta differential privacy model.
|
||||
|
||||
The privacy guarantees are dependent on the noise that is sampled. Please
|
||||
see the paper linked below for more details.
|
||||
|
@ -52,7 +52,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
|
|||
seed: random seed to use
|
||||
dtype: data type to use for tensors
|
||||
"""
|
||||
super(BoltonModel, self).__init__(name='bolton', dynamic=False)
|
||||
super(BoltOnModel, self).__init__(name='bolton', dynamic=False)
|
||||
if n_outputs <= 0:
|
||||
raise ValueError('n_outputs = {0} is not valid. Must be > 0.'.format(
|
||||
n_outputs
|
||||
|
@ -69,6 +69,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
|
|||
inputs: inputs to neural network
|
||||
|
||||
Returns:
|
||||
Output logits for the given inputs.
|
||||
|
||||
"""
|
||||
return self.output_layer(inputs)
|
||||
|
@ -78,11 +79,11 @@ class BoltonModel(Model): # pylint: disable=abstract-method
|
|||
loss,
|
||||
kernel_initializer=tf.initializers.GlorotUniform,
|
||||
**kwargs): # pylint: disable=arguments-differ
|
||||
"""See super class. Default optimizer used in Bolton method is SGD.
|
||||
"""See super class. Default optimizer used in BoltOn method is SGD.
|
||||
|
||||
Args:
|
||||
optimizer: The optimizer to use. This will be automatically wrapped
|
||||
with the Bolton Optimizer.
|
||||
with the BoltOn Optimizer.
|
||||
loss: The loss function to use. Must be a StrongConvex loss (extend the
|
||||
StrongConvexMixin).
|
||||
kernel_initializer: The kernel initializer to use for the single layer.
|
||||
|
@ -99,11 +100,11 @@ class BoltonModel(Model): # pylint: disable=abstract-method
|
|||
kernel_initializer=kernel_initializer(),
|
||||
)
|
||||
self._layers_instantiated = True
|
||||
if not isinstance(optimizer, Bolton):
|
||||
if not isinstance(optimizer, BoltOn):
|
||||
optimizer = optimizers.get(optimizer)
|
||||
optimizer = Bolton(optimizer, loss)
|
||||
optimizer = BoltOn(optimizer, loss)
|
||||
|
||||
super(BoltonModel, self).compile(optimizer, loss=loss, **kwargs)
|
||||
super(BoltOnModel, self).compile(optimizer, loss=loss, **kwargs)
|
||||
|
||||
def fit(self,
|
||||
x=None,
|
||||
|
@ -115,7 +116,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
|
|||
noise_distribution='laplace',
|
||||
steps_per_epoch=None,
|
||||
**kwargs): # pylint: disable=arguments-differ
|
||||
"""Reroutes to super fit with Bolton delta-epsilon privacy requirements.
|
||||
"""Reroutes to super fit with BoltOn delta-epsilon privacy requirements.
|
||||
|
||||
Note, inputs must be normalized s.t. ||x|| < 1.
|
||||
Requirements are as follows:
|
||||
|
@ -126,9 +127,9 @@ class BoltonModel(Model): # pylint: disable=abstract-method
|
|||
See super implementation for more details.
|
||||
|
||||
Args:
|
||||
x:
|
||||
y:
|
||||
batch_size:
|
||||
x: Inputs to fit on, see super.
|
||||
y: Labels to fit on, see super.
|
||||
batch_size: The batch size to use for training, see super.
|
||||
class_weight: the class weights to be used. Can be a scalar or 1D tensor
|
||||
whose dim == n_classes.
|
||||
n_samples: the number of individual samples in x.
|
||||
|
@ -139,7 +140,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
|
|||
**kwargs: kwargs to keras Model.fit. See super.
|
||||
|
||||
Returns:
|
||||
output
|
||||
Output from super fit method.
|
||||
"""
|
||||
if class_weight is None:
|
||||
class_weight_ = self.calculate_class_weights(class_weight)
|
||||
|
@ -170,7 +171,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
|
|||
class_weight_,
|
||||
data_size,
|
||||
batch_size_) as _:
|
||||
out = super(BoltonModel, self).fit(x=x,
|
||||
out = super(BoltOnModel, self).fit(x=x,
|
||||
y=y,
|
||||
batch_size=batch_size,
|
||||
class_weight=class_weight,
|
||||
|
@ -192,18 +193,18 @@ class BoltonModel(Model): # pylint: disable=abstract-method
|
|||
is a generator. See super method and fit for more details.
|
||||
|
||||
Args:
|
||||
generator:
|
||||
generator: Inputs generator following Tensorflow guidelines, see super.
|
||||
class_weight: the class weights to be used. Can be a scalar or 1D tensor
|
||||
whose dim == n_classes.
|
||||
noise_distribution: the distribution to get noise from.
|
||||
epsilon: privacy parameter, which trades off utility and privacy. See
|
||||
Bolton paper for more description.
|
||||
BoltOn paper for more description.
|
||||
n_samples: number of individual samples in x
|
||||
steps_per_epoch:
|
||||
steps_per_epoch: Number of steps per training epoch, see super.
|
||||
**kwargs: **kwargs
|
||||
|
||||
Returns:
|
||||
output
|
||||
Output from super fit_generator method.
|
||||
"""
|
||||
if class_weight is None:
|
||||
class_weight = self.calculate_class_weights(class_weight)
|
||||
|
@ -224,7 +225,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
|
|||
class_weight,
|
||||
data_size,
|
||||
batch_size) as _:
|
||||
out = super(BoltonModel, self).fit_generator(
|
||||
out = super(BoltOnModel, self).fit_generator(
|
||||
generator,
|
||||
class_weight=class_weight,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
|
|
|
@ -26,11 +26,11 @@ from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
|
|||
from tensorflow.python.keras.regularizers import L1L2
|
||||
from privacy.bolton import models
|
||||
from privacy.bolton.losses import StrongConvexMixin
|
||||
from privacy.bolton.optimizers import Bolton
|
||||
from privacy.bolton.optimizers import BoltOn
|
||||
|
||||
|
||||
class TestLoss(losses.Loss, StrongConvexMixin):
|
||||
"""Test loss function for testing Bolton model."""
|
||||
"""Test loss function for testing BoltOn model."""
|
||||
|
||||
def __init__(self, reg_lambda, c_arg, radius_constant, name='test'):
|
||||
super(TestLoss, self).__init__(name=name)
|
||||
|
@ -105,7 +105,7 @@ class TestLoss(losses.Loss, StrongConvexMixin):
|
|||
|
||||
|
||||
class TestOptimizer(OptimizerV2):
|
||||
"""Test optimizer used for testing Bolton model."""
|
||||
"""Test optimizer used for testing BoltOn model."""
|
||||
|
||||
def __init__(self):
|
||||
super(TestOptimizer, self).__init__('test')
|
||||
|
@ -138,14 +138,14 @@ class InitTests(keras_parameterized.TestCase):
|
|||
},
|
||||
])
|
||||
def test_init_params(self, n_outputs):
|
||||
"""Test initialization of BoltonModel.
|
||||
"""Test initialization of BoltOnModel.
|
||||
|
||||
Args:
|
||||
n_outputs: number of output neurons
|
||||
"""
|
||||
# test valid domains for each variable
|
||||
clf = models.BoltonModel(n_outputs)
|
||||
self.assertIsInstance(clf, models.BoltonModel)
|
||||
clf = models.BoltOnModel(n_outputs)
|
||||
self.assertIsInstance(clf, models.BoltOnModel)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'invalid n_outputs',
|
||||
|
@ -153,14 +153,14 @@ class InitTests(keras_parameterized.TestCase):
|
|||
},
|
||||
])
|
||||
def test_bad_init_params(self, n_outputs):
|
||||
"""test bad initializations of BoltonModel that should raise errors.
|
||||
"""test bad initializations of BoltOnModel that should raise errors.
|
||||
|
||||
Args:
|
||||
n_outputs: number of output neurons
|
||||
"""
|
||||
# test invalid domains for each variable, especially noise
|
||||
with self.assertRaises(ValueError):
|
||||
models.BoltonModel(n_outputs)
|
||||
models.BoltOnModel(n_outputs)
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'string compile',
|
||||
|
@ -175,7 +175,7 @@ class InitTests(keras_parameterized.TestCase):
|
|||
},
|
||||
])
|
||||
def test_compile(self, n_outputs, loss, optimizer):
|
||||
"""Test compilation of BoltonModel.
|
||||
"""Test compilation of BoltOnModel.
|
||||
|
||||
Args:
|
||||
n_outputs: number of output neurons
|
||||
|
@ -184,7 +184,7 @@ class InitTests(keras_parameterized.TestCase):
|
|||
"""
|
||||
# test compilation of valid tf.optimizer and tf.loss
|
||||
with self.cached_session():
|
||||
clf = models.BoltonModel(n_outputs)
|
||||
clf = models.BoltOnModel(n_outputs)
|
||||
clf.compile(optimizer, loss)
|
||||
self.assertEqual(clf.loss, loss)
|
||||
|
||||
|
@ -201,7 +201,7 @@ class InitTests(keras_parameterized.TestCase):
|
|||
}
|
||||
])
|
||||
def test_bad_compile(self, n_outputs, loss, optimizer):
|
||||
"""test bad compilations of BoltonModel that should raise errors.
|
||||
"""test bad compilations of BoltOnModel that should raise errors.
|
||||
|
||||
Args:
|
||||
n_outputs: number of output neurons
|
||||
|
@ -211,7 +211,7 @@ class InitTests(keras_parameterized.TestCase):
|
|||
# test compilaton of invalid tf.optimizer and non instantiated loss.
|
||||
with self.cached_session():
|
||||
with self.assertRaises((ValueError, AttributeError)):
|
||||
clf = models.BoltonModel(n_outputs)
|
||||
clf = models.BoltOnModel(n_outputs)
|
||||
clf.compile(optimizer, loss)
|
||||
|
||||
|
||||
|
@ -276,9 +276,9 @@ def _do_fit(n_samples,
|
|||
distribution: distribution to get noise from.
|
||||
|
||||
Returns:
|
||||
BoltonModel instsance
|
||||
BoltOnModel instsance
|
||||
"""
|
||||
clf = models.BoltonModel(n_outputs)
|
||||
clf = models.BoltOnModel(n_outputs)
|
||||
clf.compile(optimizer, loss)
|
||||
if generator:
|
||||
x = _cat_dataset(
|
||||
|
@ -328,14 +328,14 @@ class FitTests(keras_parameterized.TestCase):
|
|||
},
|
||||
])
|
||||
def test_fit(self, generator, reset_n_samples):
|
||||
"""Tests fitting of BoltonModel.
|
||||
"""Tests fitting of BoltOnModel.
|
||||
|
||||
Args:
|
||||
generator: True for generator test, False for iterator test.
|
||||
reset_n_samples: True to reset the n_samples to None, False does nothing
|
||||
"""
|
||||
loss = TestLoss(1, 1, 1)
|
||||
optimizer = Bolton(TestOptimizer(), loss)
|
||||
optimizer = BoltOn(TestOptimizer(), loss)
|
||||
n_classes = 2
|
||||
input_dim = 5
|
||||
epsilon = 1
|
||||
|
@ -360,7 +360,7 @@ class FitTests(keras_parameterized.TestCase):
|
|||
},
|
||||
])
|
||||
def test_fit_gen(self, generator):
|
||||
"""Tests the fit_generator method of BoltonModel.
|
||||
"""Tests the fit_generator method of BoltOnModel.
|
||||
|
||||
Args:
|
||||
generator: True to test with a generator dataset
|
||||
|
@ -371,7 +371,7 @@ class FitTests(keras_parameterized.TestCase):
|
|||
input_dim = 5
|
||||
batch_size = 1
|
||||
n_samples = 10
|
||||
clf = models.BoltonModel(n_classes)
|
||||
clf = models.BoltOnModel(n_classes)
|
||||
clf.compile(optimizer, loss)
|
||||
x = _cat_dataset(
|
||||
n_samples,
|
||||
|
@ -456,7 +456,7 @@ class FitTests(keras_parameterized.TestCase):
|
|||
num_classes: number of outputs neurons
|
||||
result: expected result
|
||||
"""
|
||||
clf = models.BoltonModel(1, 1)
|
||||
clf = models.BoltOnModel(1, 1)
|
||||
expected = clf.calculate_class_weights(class_weights,
|
||||
class_counts,
|
||||
num_classes)
|
||||
|
@ -523,7 +523,7 @@ class FitTests(keras_parameterized.TestCase):
|
|||
num_classes: number of outputs neurons.
|
||||
err_msg: The expected error message.
|
||||
"""
|
||||
clf = models.BoltonModel(1, 1)
|
||||
clf = models.BoltOnModel(1, 1)
|
||||
with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method
|
||||
clf.calculate_class_weights(class_weights,
|
||||
class_counts,
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Bolton Optimizer for bolton method."""
|
||||
"""BoltOn Optimizer for bolton method."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -45,15 +45,15 @@ class GammaBetaDecreasingStep(
|
|||
|
||||
Returns:
|
||||
decayed learning rate to minimum of 1/beta and 1/(gamma * step) as per
|
||||
the Bolton privacy requirements.
|
||||
the BoltOn privacy requirements.
|
||||
"""
|
||||
if not self.is_init:
|
||||
raise AttributeError('Please initialize the {0} Learning Rate Scheduler.'
|
||||
'This is performed automatically by using the '
|
||||
'{1} as a context manager, '
|
||||
'as desired'.format(self.__class__.__name__,
|
||||
Bolton.__class__.__name__
|
||||
)
|
||||
BoltOn.__class__.__name__
|
||||
)
|
||||
)
|
||||
dtype = self.beta.dtype
|
||||
one = tf.constant(1, dtype)
|
||||
|
@ -86,22 +86,22 @@ class GammaBetaDecreasingStep(
|
|||
self.gamma = None
|
||||
|
||||
|
||||
class Bolton(optimizer_v2.OptimizerV2):
|
||||
"""Wrap another tf optimizer with Bolton privacy protocol.
|
||||
class BoltOn(optimizer_v2.OptimizerV2):
|
||||
"""Wrap another tf optimizer with BoltOn privacy protocol.
|
||||
|
||||
Bolton optimizer wraps another tf optimizer to be used
|
||||
BoltOn optimizer wraps another tf optimizer to be used
|
||||
as the visible optimizer to the tf model. No matter the optimizer
|
||||
passed, "Bolton" enables the bolton model to control the learning rate
|
||||
passed, "BoltOn" enables the bolton model to control the learning rate
|
||||
based on the strongly convex loss.
|
||||
|
||||
To use the Bolton method, you must:
|
||||
To use the BoltOn method, you must:
|
||||
1. instantiate it with an instantiated tf optimizer and StrongConvexLoss.
|
||||
2. use it as a context manager around your .fit method internals.
|
||||
|
||||
This can be accomplished by the following:
|
||||
optimizer = tf.optimizers.SGD()
|
||||
loss = privacy.bolton.losses.StrongConvexBinaryCrossentropy()
|
||||
bolton = Bolton(optimizer, loss)
|
||||
bolton = BoltOn(optimizer, loss)
|
||||
with bolton(*args) as _:
|
||||
model.fit()
|
||||
The args required for the context manager can be found in the __call__
|
||||
|
@ -142,7 +142,7 @@ class Bolton(optimizer_v2.OptimizerV2):
|
|||
'_is_init'
|
||||
]
|
||||
self._internal_optimizer = optimizer
|
||||
self.learning_rate = GammaBetaDecreasingStep() # use the Bolton Learning
|
||||
self.learning_rate = GammaBetaDecreasingStep() # use the BoltOn Learning
|
||||
# rate scheduler, as required for privacy guarantees. This will still need
|
||||
# to get values from the loss function near the time that .fit is called
|
||||
# on the model (when this optimizer will be called as a context manager)
|
||||
|
@ -162,7 +162,7 @@ class Bolton(optimizer_v2.OptimizerV2):
|
|||
False to check if weights > R-ball and only normalize then.
|
||||
|
||||
Raises:
|
||||
Exception:
|
||||
Exception: If not called from inside this optimizer context.
|
||||
"""
|
||||
if not self._is_init:
|
||||
raise Exception('This method must be called from within the optimizer\'s '
|
||||
|
@ -190,7 +190,7 @@ class Bolton(optimizer_v2.OptimizerV2):
|
|||
Noise in shape of layer's weights to be added to the weights.
|
||||
|
||||
Raises:
|
||||
Exception:
|
||||
Exception: If not called from inside this optimizer's context.
|
||||
"""
|
||||
if not self._is_init:
|
||||
raise Exception('This method must be called from within the optimizer\'s '
|
||||
|
@ -234,10 +234,10 @@ class Bolton(optimizer_v2.OptimizerV2):
|
|||
from the _internal_optimizer instance.
|
||||
|
||||
Args:
|
||||
name:
|
||||
name: Name of attribute to get from this or aggregate optimizer.
|
||||
|
||||
Returns:
|
||||
attribute from Bolton if specified to come from self, else
|
||||
attribute from BoltOn if specified to come from self, else
|
||||
from _internal_optimizer.
|
||||
"""
|
||||
if name == '_private_attributes' or name in self._private_attributes:
|
||||
|
@ -336,7 +336,7 @@ class Bolton(optimizer_v2.OptimizerV2):
|
|||
batch_size: batch size used.
|
||||
|
||||
Returns:
|
||||
self
|
||||
self, to be used by the __enter__ method for context.
|
||||
"""
|
||||
if epsilon <= 0:
|
||||
raise ValueError('Detected epsilon: {0}. '
|
||||
|
@ -361,7 +361,7 @@ class Bolton(optimizer_v2.OptimizerV2):
|
|||
|
||||
Used to:
|
||||
1.reset the model and fit parameters passed to the optimizer
|
||||
to enable the Bolton Privacy guarantees. These are reset to ensure
|
||||
to enable the BoltOn Privacy guarantees. These are reset to ensure
|
||||
that any future calls to fit with the same instance of the optimizer
|
||||
will properly error out.
|
||||
|
||||
|
@ -369,7 +369,7 @@ class Bolton(optimizer_v2.OptimizerV2):
|
|||
adding noise to the weights.
|
||||
|
||||
Args:
|
||||
*args: *args
|
||||
*args: encompasses the type, value, and traceback values which are unused.
|
||||
"""
|
||||
self.project_weights_to_r(True)
|
||||
for layer in self.layers:
|
||||
|
|
|
@ -33,7 +33,7 @@ from privacy.bolton.losses import StrongConvexMixin
|
|||
|
||||
|
||||
class TestModel(Model): # pylint: disable=abstract-method
|
||||
"""Bolton episilon-delta model.
|
||||
"""BoltOn episilon-delta model.
|
||||
|
||||
Uses 4 key steps to achieve privacy guarantees:
|
||||
1. Adds noise to weights after training (output perturbation).
|
||||
|
@ -66,7 +66,7 @@ class TestModel(Model): # pylint: disable=abstract-method
|
|||
|
||||
|
||||
class TestLoss(losses.Loss, StrongConvexMixin):
|
||||
"""Test loss function for testing Bolton model."""
|
||||
"""Test loss function for testing BoltOn model."""
|
||||
|
||||
def __init__(self, reg_lambda, c_arg, radius_constant, name='test'):
|
||||
super(TestLoss, self).__init__(name=name)
|
||||
|
@ -142,7 +142,7 @@ class TestLoss(losses.Loss, StrongConvexMixin):
|
|||
|
||||
|
||||
class TestOptimizer(OptimizerV2):
|
||||
"""Optimizer used for testing the Bolton optimizer."""
|
||||
"""Optimizer used for testing the BoltOn optimizer."""
|
||||
|
||||
def __init__(self):
|
||||
super(TestOptimizer, self).__init__('test')
|
||||
|
@ -185,7 +185,7 @@ class TestOptimizer(OptimizerV2):
|
|||
|
||||
|
||||
class BoltonOptimizerTest(keras_parameterized.TestCase):
|
||||
"""Bolton Optimizer tests."""
|
||||
"""BoltOn Optimizer tests."""
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
@parameterized.named_parameters([
|
||||
{'testcase_name': 'getattr',
|
||||
|
@ -201,19 +201,19 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
|
|||
])
|
||||
|
||||
def test_fn(self, fn, args, result, test_attr):
|
||||
"""test that a fn of Bolton optimizer is working as expected.
|
||||
"""test that a fn of BoltOn optimizer is working as expected.
|
||||
|
||||
Args:
|
||||
fn: method of Optimizer to test
|
||||
args: args to optimizer fn
|
||||
result: the expected result
|
||||
test_attr: None if the fn returns the test result. Otherwise, this is
|
||||
the attribute of Bolton to check against result with.
|
||||
the attribute of BoltOn to check against result with.
|
||||
|
||||
"""
|
||||
tf.random.set_seed(1)
|
||||
loss = TestLoss(1, 1, 1)
|
||||
bolton = opt.Bolton(TestOptimizer(), loss)
|
||||
bolton = opt.BoltOn(TestOptimizer(), loss)
|
||||
model = TestModel(1)
|
||||
model.layers[0].kernel = \
|
||||
model.layers[0].kernel_initializer((model.layer_input_shape[0],
|
||||
|
@ -260,7 +260,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
|
|||
'result': [[1]]},
|
||||
])
|
||||
def test_project(self, r, shape, n_out, init_value, result):
|
||||
"""test that a fn of Bolton optimizer is working as expected.
|
||||
"""test that a fn of BoltOn optimizer is working as expected.
|
||||
|
||||
Args:
|
||||
r: Radius value for StrongConvex loss function.
|
||||
|
@ -273,7 +273,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
|
|||
@tf.function
|
||||
def project_fn(r):
|
||||
loss = TestLoss(1, 1, r)
|
||||
bolton = opt.Bolton(TestOptimizer(), loss)
|
||||
bolton = opt.BoltOn(TestOptimizer(), loss)
|
||||
model = TestModel(n_out, shape, init_value)
|
||||
model.compile(bolton, loss)
|
||||
model.layers[0].kernel = \
|
||||
|
@ -308,7 +308,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
|
|||
@tf.function
|
||||
def test_run():
|
||||
loss = TestLoss(1, 1, 1)
|
||||
bolton = opt.Bolton(TestOptimizer(), loss)
|
||||
bolton = opt.BoltOn(TestOptimizer(), loss)
|
||||
model = TestModel(1, (1,), 1)
|
||||
model.compile(bolton, loss)
|
||||
model.layers[0].kernel = \
|
||||
|
@ -343,7 +343,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
|
|||
@tf.function
|
||||
def test_run(noise, epsilon):
|
||||
loss = TestLoss(1, 1, 1)
|
||||
bolton = opt.Bolton(TestOptimizer(), loss)
|
||||
bolton = opt.BoltOn(TestOptimizer(), loss)
|
||||
model = TestModel(1, (1,), 1)
|
||||
model.compile(bolton, loss)
|
||||
model.layers[0].kernel = \
|
||||
|
@ -371,7 +371,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
|
|||
@tf.function
|
||||
def test_run(fn, args):
|
||||
loss = TestLoss(1, 1, 1)
|
||||
bolton = opt.Bolton(TestOptimizer(), loss)
|
||||
bolton = opt.BoltOn(TestOptimizer(), loss)
|
||||
model = TestModel(1, (1,), 1)
|
||||
model.compile(bolton, loss)
|
||||
model.layers[0].kernel = \
|
||||
|
@ -415,7 +415,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
|
|||
"""Tests rerouted function.
|
||||
|
||||
Tests that a method of the internal optimizer is correctly routed from
|
||||
the Bolton instance to the internal optimizer instance (TestOptimizer,
|
||||
the BoltOn instance to the internal optimizer instance (TestOptimizer,
|
||||
here).
|
||||
|
||||
Args:
|
||||
|
@ -424,7 +424,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
|
|||
"""
|
||||
loss = TestLoss(1, 1, 1)
|
||||
optimizer = TestOptimizer()
|
||||
bolton = opt.Bolton(optimizer, loss)
|
||||
bolton = opt.BoltOn(optimizer, loss)
|
||||
model = TestModel(3)
|
||||
model.compile(optimizer, loss)
|
||||
model.layers[0].kernel = \
|
||||
|
@ -465,7 +465,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
|
|||
@tf.function
|
||||
def test_run(fn, args):
|
||||
loss = TestLoss(1, 1, 1)
|
||||
bolton = opt.Bolton(TestOptimizer(), loss)
|
||||
bolton = opt.BoltOn(TestOptimizer(), loss)
|
||||
model = TestModel(1, (1,), 1)
|
||||
model.compile(bolton, loss)
|
||||
model.layers[0].kernel = \
|
||||
|
@ -502,7 +502,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
|
|||
"""
|
||||
loss = TestLoss(1, 1, 1)
|
||||
internal_optimizer = TestOptimizer()
|
||||
optimizer = opt.Bolton(internal_optimizer, loss)
|
||||
optimizer = opt.BoltOn(internal_optimizer, loss)
|
||||
self.assertEqual(getattr(optimizer, attr),
|
||||
getattr(internal_optimizer, attr))
|
||||
|
||||
|
@ -521,7 +521,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
|
|||
"""
|
||||
loss = TestLoss(1, 1, 1)
|
||||
internal_optimizer = TestOptimizer()
|
||||
optimizer = opt.Bolton(internal_optimizer, loss)
|
||||
optimizer = opt.BoltOn(internal_optimizer, loss)
|
||||
with self.assertRaises(AttributeError):
|
||||
getattr(optimizer, attr)
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ from __future__ import print_function
|
|||
import tensorflow as tf # pylint: disable=wrong-import-position
|
||||
from privacy.bolton import losses # pylint: disable=wrong-import-position
|
||||
from privacy.bolton import models # pylint: disable=wrong-import-position
|
||||
from privacy.bolton.optimizers import Bolton # pylint: disable=wrong-import-position
|
||||
from privacy.bolton.optimizers import BoltOn # pylint: disable=wrong-import-position
|
||||
# -------
|
||||
# First, we will create a binary classification dataset with a single output
|
||||
# dimension. The samples for each label are repeated data points at different
|
||||
|
@ -39,12 +39,12 @@ generator = tf.data.Dataset.from_tensor_slices((x, y))
|
|||
generator = generator.batch(10)
|
||||
generator = generator.shuffle(10)
|
||||
# -------
|
||||
# First, we will explore using the pre - built BoltonModel, which is a thin
|
||||
# First, we will explore using the pre - built BoltOnModel, which is a thin
|
||||
# wrapper around a Keras Model using a single - layer neural network.
|
||||
# It automatically uses the Bolton Optimizer which encompasses all the logic
|
||||
# required for the Bolton Differential Privacy method.
|
||||
# It automatically uses the BoltOn Optimizer which encompasses all the logic
|
||||
# required for the BoltOn Differential Privacy method.
|
||||
# -------
|
||||
bolt = models.BoltonModel(n_outputs) # tell the model how many outputs we have.
|
||||
bolt = models.BoltOnModel(n_outputs) # tell the model how many outputs we have.
|
||||
# -------
|
||||
# Now, we will pick our optimizer and Strongly Convex Loss function. The loss
|
||||
# must extend from StrongConvexMixin and implement the associated methods.Some
|
||||
|
@ -60,7 +60,7 @@ loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
|
|||
# to be 1; these are all tunable and their impact can be read in losses.
|
||||
# StrongConvexBinaryCrossentropy.We then compile the model with the chosen
|
||||
# optimizer and loss, which will automatically wrap the chosen optimizer with
|
||||
# the Bolton Optimizer, ensuring the required components function as required
|
||||
# the BoltOn Optimizer, ensuring the required components function as required
|
||||
# for privacy guarantees.
|
||||
# -------
|
||||
bolt.compile(optimizer, loss)
|
||||
|
@ -77,7 +77,7 @@ bolt.compile(optimizer, loss)
|
|||
# 2. noise_distribution, a valid string indicating the distriution to use (must
|
||||
# be implemented)
|
||||
#
|
||||
# The BoltonModel offers a helper method,.calculate_class_weight to aid in
|
||||
# The BoltOnModel offers a helper method,.calculate_class_weight to aid in
|
||||
# class_weight calculation.
|
||||
# required parameters
|
||||
# -------
|
||||
|
@ -132,7 +132,7 @@ bolt.fit(generator,
|
|||
noise_distribution=noise_distribution,
|
||||
verbose=0)
|
||||
# -------
|
||||
# You don't have to use the bolton model to use the Bolton method.
|
||||
# You don't have to use the bolton model to use the BoltOn method.
|
||||
# There are only a few requirements:
|
||||
# 1. make sure any requirements from the loss are implemented in the model.
|
||||
# 2. instantiate the optimizer and use it as a context around the fit operation.
|
||||
|
@ -140,7 +140,7 @@ bolt.fit(generator,
|
|||
# -------------------- Part 2, using the Optimizer
|
||||
|
||||
# -------
|
||||
# Here, we create our own model and setup the Bolton optimizer.
|
||||
# Here, we create our own model and setup the BoltOn optimizer.
|
||||
# -------
|
||||
|
||||
|
||||
|
@ -157,7 +157,7 @@ class TestModel(tf.keras.Model): # pylint: disable=abstract-method
|
|||
|
||||
optimizer = tf.optimizers.SGD()
|
||||
loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
|
||||
optimizer = Bolton(optimizer, loss)
|
||||
optimizer = BoltOn(optimizer, loss)
|
||||
# -------
|
||||
# Now, we instantiate our model and check for 1. Since our loss requires L2
|
||||
# regularization over the kernel, we will pass it to the model.
|
||||
|
@ -166,7 +166,7 @@ n_outputs = 1 # parameter for model and optimizer context.
|
|||
test_model = TestModel(loss.kernel_regularizer(), n_outputs)
|
||||
test_model.compile(optimizer, loss)
|
||||
# -------
|
||||
# We comply with 2., and use the Bolton Optimizer as a context around the fit
|
||||
# We comply with 2., and use the BoltOn Optimizer as a context around the fit
|
||||
# method.
|
||||
# -------
|
||||
# parameters for context
|
||||
|
|
Loading…
Reference in a new issue