many fixes

This commit is contained in:
npapernot 2019-07-25 15:37:54 +00:00
parent fe90e3c596
commit 8e6bcf9b4a
6 changed files with 135 additions and 144 deletions

View file

@ -20,7 +20,7 @@ if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
raise ImportError("Please upgrade your version " raise ImportError("Please upgrade your version "
"of tensorflow from: {0} to at least 2.0.0 to " "of tensorflow from: {0} to at least 2.0.0 to "
"use privacy/bolton".format(LooseVersion(tf.__version__))) "use privacy/bolton".format(LooseVersion(tf.__version__)))
if hasattr(sys, 'skip_tf_privacy_import'): # Useful for standalone scripts. if hasattr(sys, "skip_tf_privacy_import"): # Useful for standalone scripts.
pass pass
else: else:
from privacy.bolton.models import BoltonModel from privacy.bolton.models import BoltonModel

View file

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Unit testing for losses.py""" """Unit testing for losses."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -20,11 +20,11 @@ from __future__ import print_function
from contextlib import contextmanager from contextlib import contextmanager
from io import StringIO from io import StringIO
import sys import sys
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.framework import test_util
from tensorflow.python.keras.regularizers import L1L2
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.regularizers import L1L2
from privacy.bolton.losses import StrongConvexBinaryCrossentropy from privacy.bolton.losses import StrongConvexBinaryCrossentropy
from privacy.bolton.losses import StrongConvexHuber from privacy.bolton.losses import StrongConvexHuber
from privacy.bolton.losses import StrongConvexMixin from privacy.bolton.losses import StrongConvexMixin
@ -43,7 +43,7 @@ def captured_output():
class StrongConvexMixinTests(keras_parameterized.TestCase): class StrongConvexMixinTests(keras_parameterized.TestCase):
"""Tests for the StrongConvexMixin""" """Tests for the StrongConvexMixin."""
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'beta not implemented', {'testcase_name': 'beta not implemented',
'fn': 'beta', 'fn': 'beta',
@ -58,6 +58,7 @@ class StrongConvexMixinTests(keras_parameterized.TestCase):
'fn': 'radius', 'fn': 'radius',
'args': []}, 'args': []},
]) ])
def test_not_implemented(self, fn, args): def test_not_implemented(self, fn, args):
"""Test that the given fn's are not implemented on the mixin. """Test that the given fn's are not implemented on the mixin.
@ -75,7 +76,7 @@ class StrongConvexMixinTests(keras_parameterized.TestCase):
'args': []}, 'args': []},
]) ])
def test_return_none(self, fn, args): def test_return_none(self, fn, args):
"""Test that fn of Mixin returns None """Test that fn of Mixin returns None.
Args: Args:
fn: fn of Mixin to test fn: fn of Mixin to test
@ -126,7 +127,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
}, # pylint: disable=invalid-name }, # pylint: disable=invalid-name
]) ])
def test_bad_init_params(self, reg_lambda, C, radius_constant): def test_bad_init_params(self, reg_lambda, C, radius_constant):
"""Test invalid domain for given params. Should return ValueError """Test invalid domain for given params. Should return ValueError.
Args: Args:
reg_lambda: initialization value for reg_lambda arg reg_lambda: initialization value for reg_lambda arg
@ -162,7 +163,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
}, },
]) ])
def test_calculation(self, logits, y_true, result): def test_calculation(self, logits, y_true, result):
"""Test the call method to ensure it returns the correct value """Test the call method to ensure it returns the correct value.
Args: Args:
logits: unscaled output of model logits: unscaled output of model
@ -202,7 +203,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
}, },
]) ])
def test_fns(self, init_args, fn, args, result): def test_fns(self, init_args, fn, args, result):
"""Test that fn of BinaryCrossentropy loss returns the correct result """Test that fn of BinaryCrossentropy loss returns the correct result.
Args: Args:
init_args: init values for loss instance init_args: init values for loss instance
@ -245,7 +246,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
class HuberTests(keras_parameterized.TestCase): class HuberTests(keras_parameterized.TestCase):
"""tests for BinaryCrossesntropy StrongConvex loss""" """tests for BinaryCrossesntropy StrongConvex loss."""
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'normal', {'testcase_name': 'normal',
@ -256,7 +257,7 @@ class HuberTests(keras_parameterized.TestCase):
}, },
]) ])
def test_init_params(self, reg_lambda, c, radius_constant, delta): def test_init_params(self, reg_lambda, c, radius_constant, delta):
"""Test initialization for given arguments """Test initialization for given arguments.
Args: Args:
reg_lambda: initialization value for reg_lambda arg reg_lambda: initialization value for reg_lambda arg

View file

@ -11,15 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Bolton model for bolton method of differentially private ML""" """Bolton model for bolton method of differentially private ML."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras.models import Model
from tensorflow.python.keras import optimizers
from tensorflow.python.framework import ops as _ops from tensorflow.python.framework import ops as _ops
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.models import Model
from privacy.bolton.losses import StrongConvexMixin from privacy.bolton.losses import StrongConvexMixin
from privacy.bolton.optimizers import Bolton from privacy.bolton.optimizers import Bolton
@ -44,9 +44,8 @@ class BoltonModel(Model): # pylint: disable=abstract-method
def __init__(self, def __init__(self,
n_outputs, n_outputs,
seed=1, seed=1,
dtype=tf.float32 dtype=tf.float32):
): """Private constructor.
""" private constructor.
Args: Args:
n_outputs: number of output classes to predict. n_outputs: number of output classes to predict.
@ -64,7 +63,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
self._dtype = dtype self._dtype = dtype
def call(self, inputs): # pylint: disable=arguments-differ def call(self, inputs): # pylint: disable=arguments-differ
"""Forward pass of network """Forward pass of network.
Args: Args:
inputs: inputs to neural network inputs: inputs to neural network
@ -111,8 +110,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
weighted_metrics=weighted_metrics, weighted_metrics=weighted_metrics,
target_tensors=target_tensors, target_tensors=target_tensors,
distribute=distribute, distribute=distribute,
**kwargs **kwargs)
)
def fit(self, def fit(self,
x=None, x=None,
@ -158,8 +156,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
data_size = None data_size = None
batch_size_ = self._validate_or_infer_batch_size(batch_size, batch_size_ = self._validate_or_infer_batch_size(batch_size,
steps_per_epoch, steps_per_epoch,
x x)
)
# inferring batch_size to be passed to optimizer. batch_size must remain its # inferring batch_size to be passed to optimizer. batch_size must remain its
# initial value when passed to super().fit() # initial value when passed to super().fit()
if batch_size_ is None: if batch_size_ is None:
@ -173,15 +170,13 @@ class BoltonModel(Model): # pylint: disable=abstract-method
self.layers, self.layers,
class_weight_, class_weight_,
data_size, data_size,
batch_size_, batch_size_) as _:
) as _:
out = super(BoltonModel, self).fit(x=x, out = super(BoltonModel, self).fit(x=x,
y=y, y=y,
batch_size=batch_size, batch_size=batch_size,
class_weight=class_weight, class_weight=class_weight,
steps_per_epoch=steps_per_epoch, steps_per_epoch=steps_per_epoch,
**kwargs **kwargs)
)
return out return out
def fit_generator(self, def fit_generator(self,
@ -191,8 +186,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
epsilon=2, epsilon=2,
n_samples=None, n_samples=None,
steps_per_epoch=None, steps_per_epoch=None,
**kwargs **kwargs): # pylint: disable=arguments-differ
): # pylint: disable=arguments-differ
""" """
This method is the same as fit except for when the passed dataset This method is the same as fit except for when the passed dataset
is a generator. See super method and fit for more details. is a generator. See super method and fit for more details.
@ -218,28 +212,24 @@ class BoltonModel(Model): # pylint: disable=abstract-method
data_size = None data_size = None
batch_size = self._validate_or_infer_batch_size(None, batch_size = self._validate_or_infer_batch_size(None,
steps_per_epoch, steps_per_epoch,
generator generator)
)
with self.optimizer(noise_distribution, with self.optimizer(noise_distribution,
epsilon, epsilon,
self.layers, self.layers,
class_weight, class_weight,
data_size, data_size,
batch_size batch_size) as _:
) as _:
out = super(BoltonModel, self).fit_generator( out = super(BoltonModel, self).fit_generator(
generator, generator,
class_weight=class_weight, class_weight=class_weight,
steps_per_epoch=steps_per_epoch, steps_per_epoch=steps_per_epoch,
**kwargs **kwargs)
)
return out return out
def calculate_class_weights(self, def calculate_class_weights(self,
class_weights=None, class_weights=None,
class_counts=None, class_counts=None,
num_classes=None num_classes=None):
):
"""Calculates class weighting to be used in training. """Calculates class weighting to be used in training.
Args: Args:
@ -283,10 +273,8 @@ class BoltonModel(Model): # pylint: disable=abstract-method
elif is_string and class_weights == 'balanced': elif is_string and class_weights == 'balanced':
num_samples = sum(class_counts) num_samples = sum(class_counts)
weighted_counts = tf.dtypes.cast(tf.math.multiply(num_classes, weighted_counts = tf.dtypes.cast(tf.math.multiply(num_classes,
class_counts, class_counts),
), self._dtype)
self._dtype
)
class_weights = tf.Variable(num_samples, dtype=self._dtype) / \ class_weights = tf.Variable(num_samples, dtype=self._dtype) / \
tf.Variable(weighted_counts, dtype=self._dtype) tf.Variable(weighted_counts, dtype=self._dtype)
else: else:
@ -298,7 +286,5 @@ class BoltonModel(Model): # pylint: disable=abstract-method
raise ValueError( raise ValueError(
"Detected array length: {0} instead of: {1}".format( "Detected array length: {0} instead of: {1}".format(
class_weights.shape[0], class_weights.shape[0],
num_classes num_classes))
)
)
return class_weights return class_weights

View file

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Unit testing for models.py""" """Unit testing for models."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -29,8 +29,10 @@ from privacy.bolton import models
from privacy.bolton.optimizers import Bolton from privacy.bolton.optimizers import Bolton
from privacy.bolton.losses import StrongConvexMixin from privacy.bolton.losses import StrongConvexMixin
class TestLoss(losses.Loss, StrongConvexMixin): class TestLoss(losses.Loss, StrongConvexMixin):
"""Test loss function for testing Bolton model""" """Test loss function for testing Bolton model."""
def __init__(self, reg_lambda, C, radius_constant, name='test'): def __init__(self, reg_lambda, C, radius_constant, name='test'):
super(TestLoss, self).__init__(name=name) super(TestLoss, self).__init__(name=name)
self.reg_lambda = reg_lambda self.reg_lambda = reg_lambda
@ -103,6 +105,7 @@ class TestLoss(losses.Loss, StrongConvexMixin):
class TestOptimizer(OptimizerV2): class TestOptimizer(OptimizerV2):
"""Test optimizer used for testing Bolton model""" """Test optimizer used for testing Bolton model"""
def __init__(self): def __init__(self):
super(TestOptimizer, self).__init__('test') super(TestOptimizer, self).__init__('test')
@ -293,8 +296,7 @@ def _do_fit(n_samples,
batch_size=batch_size, batch_size=batch_size,
n_samples=n_samples, n_samples=n_samples,
noise_distribution=distribution, noise_distribution=distribution,
epsilon=epsilon epsilon=epsilon)
)
return clf return clf
@ -453,8 +455,7 @@ class FitTests(keras_parameterized.TestCase):
clf = models.BoltonModel(1, 1) clf = models.BoltonModel(1, 1)
expected = clf.calculate_class_weights(class_weights, expected = clf.calculate_class_weights(class_weights,
class_counts, class_counts,
num_classes num_classes)
)
if hasattr(expected, 'numpy'): if hasattr(expected, 'numpy'):
expected = expected.numpy() expected = expected.numpy()
@ -467,13 +468,13 @@ class FitTests(keras_parameterized.TestCase):
'class_weights': 'not_valid', 'class_weights': 'not_valid',
'class_counts': 1, 'class_counts': 1,
'num_classes': 1, 'num_classes': 1,
'err_msg': "Detected string class_weights with value: not_valid"}, 'err_msg': 'Detected string class_weights with value: not_valid'},
{'testcase_name': 'no class counts', {'testcase_name': 'no class counts',
'class_weights': 'balanced', 'class_weights': 'balanced',
'class_counts': None, 'class_counts': None,
'num_classes': 1, 'num_classes': 1,
'err_msg': "Class counts must be provided if " 'err_msg': 'Class counts must be provided if '
"using class_weights=balanced"}, 'using class_weights=balanced'},
{'testcase_name': 'no num classes', {'testcase_name': 'no num classes',
'class_weights': 'balanced', 'class_weights': 'balanced',
'class_counts': [1], 'class_counts': [1],
@ -489,8 +490,8 @@ class FitTests(keras_parameterized.TestCase):
'class_weights': [1], 'class_weights': [1],
'class_counts': None, 'class_counts': None,
'num_classes': None, 'num_classes': None,
'err_msg': "You must pass a value for num_classes if " 'err_msg': 'You must pass a value for num_classes if '
"creating an array of class_weights"}, 'creating an array of class_weights'},
{'testcase_name': 'class counts array, improper shape', {'testcase_name': 'class counts array, improper shape',
'class_weights': [[1], [1]], 'class_weights': [[1], [1]],
'class_counts': None, 'class_counts': None,
@ -500,14 +501,13 @@ class FitTests(keras_parameterized.TestCase):
'class_weights': [1, 1, 1], 'class_weights': [1, 1, 1],
'class_counts': None, 'class_counts': None,
'num_classes': 2, 'num_classes': 2,
'err_msg': "Detected array length:"}, 'err_msg': 'Detected array length:'},
]) ])
def test_class_errors(self, def test_class_errors(self,
class_weights, class_weights,
class_counts, class_counts,
num_classes, num_classes,
err_msg err_msg):
):
"""Tests the BOltonModel calculate_class_weights method with invalid params """Tests the BOltonModel calculate_class_weights method with invalid params
which should raise the expected errors. which should raise the expected errors.
@ -521,8 +521,7 @@ class FitTests(keras_parameterized.TestCase):
with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method
clf.calculate_class_weights(class_weights, clf.calculate_class_weights(class_weights,
class_counts, class_counts,
num_classes num_classes)
)
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -11,29 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Unit testing for optimizers.py""" """Unit testing for optimizers."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import tensorflow as tf
from tensorflow.python.platform import test
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.regularizers import L1L2
from tensorflow.python.keras.initializers import constant
from tensorflow.python.keras import losses
from tensorflow.python.keras.models import Model
from tensorflow.python.framework import test_util
from tensorflow.python import ops as _ops
from absl.testing import parameterized from absl.testing import parameterized
from privacy.bolton.losses import StrongConvexMixin import tensorflow as tf
from tensorflow.python import ops as _ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import losses
from tensorflow.python.keras.initializers import constant
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
from tensorflow.python.keras.regularizers import L1L2
from tensorflow.python.platform import test
from privacy.bolton import optimizers as opt from privacy.bolton import optimizers as opt
from privacy.bolton.losses import StrongConvexMixin
class TestModel(Model): # pylint: disable=abstract-method class TestModel(Model): # pylint: disable=abstract-method
"""Bolton episilon-delta model. """Bolton episilon-delta model.
Uses 4 key steps to achieve privacy guarantees: Uses 4 key steps to achieve privacy guarantees:
1. Adds noise to weights after training (output perturbation). 1. Adds noise to weights after training (output perturbation).
2. Projects weights to R after each batch 2. Projects weights to R after each batch
@ -68,7 +69,8 @@ class TestModel(Model): # pylint: disable=abstract-method
class TestLoss(losses.Loss, StrongConvexMixin): class TestLoss(losses.Loss, StrongConvexMixin):
"""Test loss function for testing Bolton model""" """Test loss function for testing Bolton model."""
def __init__(self, reg_lambda, C, radius_constant, name='test'): def __init__(self, reg_lambda, C, radius_constant, name='test'):
super(TestLoss, self).__init__(name=name) super(TestLoss, self).__init__(name=name)
self.reg_lambda = reg_lambda self.reg_lambda = reg_lambda
@ -77,6 +79,7 @@ class TestLoss(losses.Loss, StrongConvexMixin):
def radius(self): def radius(self):
"""Radius, R, of the hypothesis space W. """Radius, R, of the hypothesis space W.
W is a convex set that forms the hypothesis space. W is a convex set that forms the hypothesis space.
Returns: radius Returns: radius
@ -117,7 +120,7 @@ class TestLoss(losses.Loss, StrongConvexMixin):
) )
def max_class_weight(self, class_weight, dtype=tf.float32): def max_class_weight(self, class_weight, dtype=tf.float32):
"""the maximum weighting in class weights (max value) as a scalar tensor """the maximum weighting in class weights (max value) as a scalar tensor.
Args: Args:
class_weight: class weights used class_weight: class weights used
@ -141,6 +144,7 @@ class TestLoss(losses.Loss, StrongConvexMixin):
class TestOptimizer(OptimizerV2): class TestOptimizer(OptimizerV2):
"""Optimizer used for testing the Bolton optimizer""" """Optimizer used for testing the Bolton optimizer"""
def __init__(self): def __init__(self):
super(TestOptimizer, self).__init__('test') super(TestOptimizer, self).__init__('test')
self.not_private = 'test' self.not_private = 'test'
@ -180,8 +184,9 @@ class TestOptimizer(OptimizerV2):
def limit_learning_rate(self): def limit_learning_rate(self):
return 'test' return 'test'
class BoltonOptimizerTest(keras_parameterized.TestCase): class BoltonOptimizerTest(keras_parameterized.TestCase):
"""Bolton Optimizer tests""" """Bolton Optimizer tests."""
@test_util.run_all_in_graph_and_eager_modes @test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'getattr', {'testcase_name': 'getattr',
@ -195,6 +200,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
'result': None, 'result': None,
'test_attr': ''}, 'test_attr': ''},
]) ])
def test_fn(self, fn, args, result, test_attr): def test_fn(self, fn, args, result, test_attr):
"""test that a fn of Bolton optimizer is working as expected. """test that a fn of Bolton optimizer is working as expected.
@ -294,7 +300,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
'class_weights': 1}, 'class_weights': 1},
]) ])
def test_context_manager(self, noise, epsilon, class_weights): def test_context_manager(self, noise, epsilon, class_weights):
"""Tests the context manager functionality of the optimizer """Tests the context manager functionality of the optimizer.
Args: Args:
noise: noise distribution to pick noise: noise distribution to pick
@ -327,7 +333,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
'err_msg': 'Detected epsilon: -1. Valid range is 0 < epsilon <inf'}, 'err_msg': 'Detected epsilon: -1. Valid range is 0 < epsilon <inf'},
]) ])
def test_context_domains(self, noise, epsilon, err_msg): def test_context_domains(self, noise, epsilon, err_msg):
""" """Tests the context domains.
Args: Args:
noise: noise distribution to pick noise: noise distribution to pick
@ -408,7 +414,9 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
'args': [1, 1]}, 'args': [1, 1]},
]) ])
def test_rerouted_function(self, fn, args): def test_rerouted_function(self, fn, args):
""" tests that a method of the internal optimizer is correctly routed from """Tests rerouted function.
Tests that a method of the internal optimizer is correctly routed from
the Bolton instance to the internal optimizer instance (TestOptimizer, the Bolton instance to the internal optimizer instance (TestOptimizer,
here). here).
@ -495,15 +503,14 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
internal_optimizer = TestOptimizer() internal_optimizer = TestOptimizer()
optimizer = opt.Bolton(internal_optimizer, loss) optimizer = opt.Bolton(internal_optimizer, loss)
self.assertEqual(getattr(optimizer, attr), self.assertEqual(getattr(optimizer, attr),
getattr(internal_optimizer, attr) getattr(internal_optimizer, attr))
)
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'attr does not exist', {'testcase_name': 'attr does not exist',
'attr': '_not_valid'} 'attr': '_not_valid'}
]) ])
def test_attribute_error(self, attr): def test_attribute_error(self, attr):
""" test that attribute of internal optimizer is correctly rerouted to """Test that attribute of internal optimizer is correctly rerouted to
the internal optimizer the internal optimizer
Args: Args:
@ -516,6 +523,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
getattr(optimizer, attr) getattr(optimizer, attr)
class SchedulerTest(keras_parameterized.TestCase): class SchedulerTest(keras_parameterized.TestCase):
"""GammaBeta Scheduler tests""" """GammaBeta Scheduler tests"""

View file

@ -11,14 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tutorial for bolton module, the model and the optimizer.""" """Tutorial for bolton module, the model and the optimizer."""
import sys
sys.path.append('..')
import tensorflow as tf # pylint: disable=wrong-import-position import tensorflow as tf # pylint: disable=wrong-import-position
from privacy.bolton import losses # pylint: disable=wrong-import-position from privacy.bolton import losses # pylint: disable=wrong-import-position
from privacy.bolton import models # pylint: disable=wrong-import-position from privacy.bolton import models # pylint: disable=wrong-import-position
from privacy.bolton.optimizers import Bolton # pylint: disable=wrong-import-position
# ------- # -------
# First, we will create a binary classification dataset with a single output # First, we will create a binary classification dataset with a single output
# dimension. The samples for each label are repeated data points at different # dimension. The samples for each label are repeated data points at different
@ -59,9 +56,9 @@ loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
# For simplicity, we pick all parameters of the StrongConvexBinaryCrossentropy # For simplicity, we pick all parameters of the StrongConvexBinaryCrossentropy
# to be 1; these are all tunable and their impact can be read in losses. # to be 1; these are all tunable and their impact can be read in losses.
# StrongConvexBinaryCrossentropy.We then compile the model with the chosen # StrongConvexBinaryCrossentropy.We then compile the model with the chosen
# optimizer and loss, which will automatically wrap the chosen optimizer with the # optimizer and loss, which will automatically wrap the chosen optimizer with
# Bolton Optimizer, ensuring the required components function as required for # the Bolton Optimizer, ensuring the required components function as required
# privacy guarantees. # for privacy guarantees.
# ------- # -------
bolt.compile(optimizer, loss) bolt.compile(optimizer, loss)
# ------- # -------
@ -69,13 +66,13 @@ bolt.compile(optimizer, loss)
# the dataset and model.These parameters are: # the dataset and model.These parameters are:
# 1. the class_weights used # 1. the class_weights used
# 2. the number of samples in the dataset # 2. the number of samples in the dataset
# 3. the batch size which the model will try to infer, if possible. If not, you # 3. the batch size which the model will try to infer, if possible. If not,
# will be required to pass these explicitly to the fit method. # you will be required to pass these explicitly to the fit method.
# #
# As well, there are two privacy parameters than can be altered: # As well, there are two privacy parameters than can be altered:
# 1. epsilon, a float # 1. epsilon, a float
# 2. noise_distribution, a valid string indicating the distriution to use (must be # 2. noise_distribution, a valid string indicating the distriution to use (must
# implemented) # be implemented)
# #
# The BoltonModel offers a helper method,.calculate_class_weight to aid in # The BoltonModel offers a helper method,.calculate_class_weight to aid in
# class_weight calculation. # class_weight calculation.
@ -117,8 +114,7 @@ try:
batch_size=batch_size, batch_size=batch_size,
n_samples=n_samples, n_samples=n_samples,
noise_distribution=noise_distribution, noise_distribution=noise_distribution,
verbose=0 verbose=0)
)
except ValueError as e: except ValueError as e:
print(e) print(e)
# ------- # -------
@ -131,8 +127,7 @@ bolt.fit(generator,
batch_size=batch_size, batch_size=batch_size,
n_samples=n_samples, n_samples=n_samples,
noise_distribution=noise_distribution, noise_distribution=noise_distribution,
verbose=0 verbose=0)
)
# ------- # -------
# You don't have to use the bolton model to use the Bolton method. # You don't have to use the bolton model to use the Bolton method.
# There are only a few requirements: # There are only a few requirements:
@ -140,16 +135,18 @@ bolt.fit(generator,
# 2. instantiate the optimizer and use it as a context around the fit operation. # 2. instantiate the optimizer and use it as a context around the fit operation.
# ------- # -------
# -------------------- Part 2, using the Optimizer # -------------------- Part 2, using the Optimizer
from privacy.bolton.optimizers import Bolton # pylint: disable=wrong-import-position
# ------- # -------
# Here, we create our own model and setup the Bolton optimizer. # Here, we create our own model and setup the Bolton optimizer.
# ------- # -------
class TestModel(tf.keras.Model): # pylint: disable=abstract-method class TestModel(tf.keras.Model): # pylint: disable=abstract-method
def __init__(self, reg_layer, number_of_outputs=1): def __init__(self, reg_layer, number_of_outputs=1):
super(TestModel, self).__init__(name='test') super(TestModel, self).__init__(name='test')
self.output_layer = tf.keras.layers.Dense(number_of_outputs, self.output_layer = tf.keras.layers.Dense(number_of_outputs,
kernel_regularizer=reg_layer kernel_regularizer=reg_layer)
)
def call(self, inputs): # pylint: disable=arguments-differ def call(self, inputs): # pylint: disable=arguments-differ
return self.output_layer(inputs) return self.output_layer(inputs)