diff --git a/privacy/__init__.py b/privacy/__init__.py index 59bfe20..e494c62 100644 --- a/privacy/__init__.py +++ b/privacy/__init__.py @@ -41,3 +41,9 @@ else: from privacy.optimizers.dp_optimizer import DPAdamOptimizer from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer + + from privacy.bolton.models import BoltonModel + from privacy.bolton.optimizers import Bolton + from privacy.bolton.losses import StrongConvexMixin + from privacy.bolton.losses import StrongConvexBinaryCrossentropy + from privacy.bolton.losses import StrongConvexHuber diff --git a/privacy/bolton/__init__.py b/privacy/bolton/__init__.py new file mode 100644 index 0000000..971b804 --- /dev/null +++ b/privacy/bolton/__init__.py @@ -0,0 +1,15 @@ +import sys +from distutils.version import LooseVersion +import tensorflow as tf + +if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): + raise ImportError("Please upgrade your version of tensorflow from: {0} " + "to at least 2.0.0 to use privacy/bolton".format( + LooseVersion(tf.__version__))) +if hasattr(sys, 'skip_tf_privacy_import'): # Useful for standalone scripts. + pass +else: + from privacy.bolton.models import BoltonModel + from privacy.bolton.optimizers import Bolton + from privacy.bolton.losses import StrongConvexHuber + from privacy.bolton.losses import StrongConvexBinaryCrossentropy \ No newline at end of file diff --git a/privacy/bolton/losses.py b/privacy/bolton/losses.py new file mode 100644 index 0000000..a99187b --- /dev/null +++ b/privacy/bolton/losses.py @@ -0,0 +1,319 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Loss functions for bolton method""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tensorflow as tf +from tensorflow.python.keras import losses +from tensorflow.python.keras.utils import losses_utils +from tensorflow.python.framework import ops as _ops +from tensorflow.python.keras.regularizers import L1L2 +from tensorflow.python.platform import tf_logging as logging + + +class StrongConvexMixin: + """ + Strong Convex Mixin base class for any loss function that will be used with + Bolton model. Subclasses must be strongly convex and implement the + associated constants. They must also conform to the requirements of tf losses + (see super class). + + For more details on the strong convexity requirements, see: + Bolt-on Differential Privacy for Scalable Stochastic Gradient + Descent-based Analytics by Xi Wu et. al. + """ + + def radius(self): + """Radius, R, of the hypothesis space W. + W is a convex set that forms the hypothesis space. + + Returns: R + + """ + raise NotImplementedError("Radius not implemented for StrongConvex Loss" + "function: %s" % str(self.__class__.__name__)) + + def gamma(self): + """ Strongly convexity, gamma + + Returns: gamma + + """ + raise NotImplementedError("Gamma not implemented for StrongConvex Loss" + "function: %s" % str(self.__class__.__name__)) + + def beta(self, class_weight): + """Smoothness, beta + + Args: + class_weight: the class weights as scalar or 1d tensor, where its + dimensionality is equal to the number of outputs. + + Returns: Beta + + """ + raise NotImplementedError("Beta not implemented for StrongConvex Loss" + "function: %s" % str(self.__class__.__name__)) + + def lipchitz_constant(self, class_weight): + """Lipchitz constant, L + + Args: + class_weight: class weights used + + Returns: L + + """ + raise NotImplementedError("lipchitz constant not implemented for " + "StrongConvex Loss" + "function: %s" % str(self.__class__.__name__)) + + def kernel_regularizer(self): + """returns the kernel_regularizer to be used. Any subclass should override + this method if they want a kernel_regularizer (if required for + the loss function to be StronglyConvex + + :return: None or kernel_regularizer layer + """ + return None + + def max_class_weight(self, class_weight, dtype): + """the maximum weighting in class weights (max value) as a scalar tensor + + Args: + class_weight: class weights used + dtype: the data type for tensor conversions. + + Returns: maximum class weighting as tensor scalar + + """ + class_weight = _ops.convert_to_tensor_v2(class_weight, dtype) + return tf.math.reduce_max(class_weight) + + +class StrongConvexHuber(losses.Loss, StrongConvexMixin): + """Strong Convex version of Huber loss using l2 weight regularization. + """ + + def __init__(self, + reg_lambda: float, + C: float, + radius_constant: float, + delta: float, + reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, + dtype=tf.float32): + """Constructor. + + Args: + reg_lambda: Weight regularization constant + C: Penalty parameter C of the loss term + radius_constant: constant defining the length of the radius + delta: delta value in huber loss. When to switch from quadratic to + absolute deviation. + reduction: reduction type to use. See super class + name: Name of the loss instance + dtype: tf datatype to use for tensor conversions. + + Returns: + Loss values per sample. + """ + if C <= 0: + raise ValueError('c: {0}, should be >= 0'.format(C)) + if reg_lambda <= 0: + raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) + if radius_constant <= 0: + raise ValueError('radius_constant: {0}, should be >= 0'.format( + radius_constant + )) + if delta <= 0: + raise ValueError('delta: {0}, should be >= 0'.format( + delta + )) + self.C = C # pylint: disable=invalid-name + self.delta = delta + self.radius_constant = radius_constant + self.dtype = dtype + self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) + super(StrongConvexHuber, self).__init__( + name='strongconvexhuber', + reduction=reduction, + ) + + def call(self, y_true, y_pred): + """Compute loss + + Args: + y_true: Ground truth values. One hot encoded using -1 and 1. + y_pred: The predicted values. + + Returns: + Loss values per sample. + """ + # return super(StrongConvexHuber, self).call(y_true, y_pred) * self._sample_weight + h = self.delta + z = y_pred * y_true + one = tf.constant(1, dtype=self.dtype) + four = tf.constant(4, dtype=self.dtype) + + if z > one + h: + return _ops.convert_to_tensor_v2(0, dtype=self.dtype) + elif tf.math.abs(one - z) <= h: + return one / (four * h) * tf.math.pow(one + h - z, 2) + elif z < one - h: + return one - z + raise ValueError('') # shouldn't be possible to get here. + + def radius(self): + """See super class. + """ + return self.radius_constant / self.reg_lambda + + def gamma(self): + """See super class. + """ + return self.reg_lambda + + def beta(self, class_weight): + """See super class. + """ + max_class_weight = self.max_class_weight(class_weight, self.dtype) + delta = _ops.convert_to_tensor_v2(self.delta, + dtype=self.dtype + ) + return self.C * max_class_weight / (delta * + tf.constant(2, dtype=self.dtype)) + \ + self.reg_lambda + + def lipchitz_constant(self, class_weight): + """See super class. + """ + # if class_weight is provided, + # it should be a vector of the same size of number of classes + max_class_weight = self.max_class_weight(class_weight, self.dtype) + lc = self.C * max_class_weight + \ + self.reg_lambda * self.radius() + return lc + + def kernel_regularizer(self): + """ + l2 loss using reg_lambda as the l2 term (as desired). Required for + this loss function to be strongly convex. + :return: + """ + return L1L2(l2=self.reg_lambda/2) + + +class StrongConvexBinaryCrossentropy( + losses.BinaryCrossentropy, + StrongConvexMixin +): + """ + Strong Convex version of BinaryCrossentropy loss using l2 weight + regularization. + """ + + def __init__(self, + reg_lambda: float, + C: float, + radius_constant: float, + from_logits: bool = True, + label_smoothing: float = 0, + reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, + dtype=tf.float32): + """ + Args: + reg_lambda: Weight regularization constant + C: Penalty parameter C of the loss term + radius_constant: constant defining the length of the radius + reduction: reduction type to use. See super class + label_smoothing: amount of smoothing to perform on labels + relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x). + Note, the impact of this parameter's effect on privacy + is not known and thus the default should be used. + name: Name of the loss instance + dtype: tf datatype to use for tensor conversions. + """ + if label_smoothing != 0: + logging.warning('The impact of label smoothing on privacy is unknown. ' + 'Use label smoothing at your own risk as it may not ' + 'guarantee privacy.') + + if reg_lambda <= 0: + raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) + if C <= 0: + raise ValueError('c: {0}, should be >= 0'.format(C)) + if radius_constant <= 0: + raise ValueError('radius_constant: {0}, should be >= 0'.format( + radius_constant + )) + self.dtype = dtype + self.C = C # pylint: disable=invalid-name + self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) + super(StrongConvexBinaryCrossentropy, self).__init__( + reduction=reduction, + name='strongconvexbinarycrossentropy', + from_logits=from_logits, + label_smoothing=label_smoothing, + ) + self.radius_constant = radius_constant + + def call(self, y_true, y_pred): + """Compute loss + + Args: + y_true: Ground truth values. + y_pred: The predicted values. + + Returns: + Loss values per sample. + """ + # loss = tf.nn.sigmoid_cross_entropy_with_logits( + # labels=y_true, + # logits=y_pred + # ) + loss = super(StrongConvexBinaryCrossentropy, self).call(y_true, y_pred) + loss = loss * self.C + return loss + + def radius(self): + """See super class. + """ + return self.radius_constant / self.reg_lambda + + def gamma(self): + """See super class. + """ + return self.reg_lambda + + def beta(self, class_weight): + """See super class. + """ + max_class_weight = self.max_class_weight(class_weight, self.dtype) + return self.C * max_class_weight + self.reg_lambda + + def lipchitz_constant(self, class_weight): + """See super class. + """ + max_class_weight = self.max_class_weight(class_weight, self.dtype) + return self.C * max_class_weight + self.reg_lambda * self.radius() + + def kernel_regularizer(self): + """ + l2 loss using reg_lambda as the l2 term (as desired). Required for + this loss function to be strongly convex. + :return: + """ + return L1L2(l2=self.reg_lambda/2) diff --git a/privacy/bolton/losses_test.py b/privacy/bolton/losses_test.py new file mode 100644 index 0000000..d2c9f80 --- /dev/null +++ b/privacy/bolton/losses_test.py @@ -0,0 +1,381 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit testing for losses.py""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.framework import test_util +from tensorflow.python.keras.regularizers import L1L2 +from absl.testing import parameterized +from privacy.bolton.losses import StrongConvexBinaryCrossentropy +from privacy.bolton.losses import StrongConvexHuber +from privacy.bolton.losses import StrongConvexMixin + + +class StrongConvexMixinTests(keras_parameterized.TestCase): + """Tests for the StrongConvexMixin""" + @parameterized.named_parameters([ + {'testcase_name': 'beta not implemented', + 'fn': 'beta', + 'args': [1]}, + {'testcase_name': 'gamma not implemented', + 'fn': 'gamma', + 'args': []}, + {'testcase_name': 'lipchitz not implemented', + 'fn': 'lipchitz_constant', + 'args': [1]}, + {'testcase_name': 'radius not implemented', + 'fn': 'radius', + 'args': []}, + ]) + def test_not_implemented(self, fn, args): + """Test that the given fn's are not implemented on the mixin. + + Args: + fn: fn on Mixin to test + args: arguments to fn of Mixin + """ + with self.assertRaises(NotImplementedError): + loss = StrongConvexMixin() + getattr(loss, fn, None)(*args) + + @parameterized.named_parameters([ + {'testcase_name': 'radius not implemented', + 'fn': 'kernel_regularizer', + 'args': []}, + ]) + def test_return_none(self, fn, args): + """Test that fn of Mixin returns None + + Args: + fn: fn of Mixin to test + args: arguments to fn of Mixin + """ + loss = StrongConvexMixin() + ret = getattr(loss, fn, None)(*args) + self.assertEqual(ret, None) + + +class BinaryCrossesntropyTests(keras_parameterized.TestCase): + """tests for BinaryCrossesntropy StrongConvex loss""" + + @parameterized.named_parameters([ + {'testcase_name': 'normal', + 'reg_lambda': 1, + 'C': 1, + 'radius_constant': 1 + }, # pylint: disable=invalid-name + ]) + def test_init_params(self, reg_lambda, C, radius_constant): + """Test initialization for given arguments + Args: + reg_lambda: initialization value for reg_lambda arg + C: initialization value for C arg + radius_constant: initialization value for radius_constant arg + """ + # test valid domains for each variable + loss = StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant) + self.assertIsInstance(loss, StrongConvexBinaryCrossentropy) + + @parameterized.named_parameters([ + {'testcase_name': 'negative c', + 'reg_lambda': 1, + 'C': -1, + 'radius_constant': 1 + }, + {'testcase_name': 'negative radius', + 'reg_lambda': 1, + 'C': 1, + 'radius_constant': -1 + }, + {'testcase_name': 'negative lambda', + 'reg_lambda': -1, + 'C': 1, + 'radius_constant': 1 + }, # pylint: disable=invalid-name + ]) + def test_bad_init_params(self, reg_lambda, C, radius_constant): + """Test invalid domain for given params. Should return ValueError + Args: + reg_lambda: initialization value for reg_lambda arg + C: initialization value for C arg + radius_constant: initialization value for radius_constant arg + """ + # test valid domains for each variable + with self.assertRaises(ValueError): + StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant) + + @test_util.run_all_in_graph_and_eager_modes + @parameterized.named_parameters([ + # [] for compatibility with tensorflow loss calculation + {'testcase_name': 'both positive', + 'logits': [10000], + 'y_true': [1], + 'result': 0, + }, + {'testcase_name': 'positive gradient negative logits', + 'logits': [-10000], + 'y_true': [1], + 'result': 10000, + }, + {'testcase_name': 'positivee gradient positive logits', + 'logits': [10000], + 'y_true': [0], + 'result': 10000, + }, + {'testcase_name': 'both negative', + 'logits': [-10000], + 'y_true': [0], + 'result': 0 + }, + ]) + def test_calculation(self, logits, y_true, result): + """Test the call method to ensure it returns the correct value + Args: + logits: unscaled output of model + y_true: label + result: correct loss calculation value + """ + logits = tf.Variable(logits, False, dtype=tf.float32) + y_true = tf.Variable(y_true, False, dtype=tf.float32) + loss = StrongConvexBinaryCrossentropy(0.00001, 1, 1) + loss = loss(y_true, logits) + self.assertEqual(loss.numpy(), result) + + @parameterized.named_parameters([ + {'testcase_name': 'beta', + 'init_args': [1, 1, 1], + 'fn': 'beta', + 'args': [1], + 'result': tf.constant(2, dtype=tf.float32) + }, + {'testcase_name': 'gamma', + 'fn': 'gamma', + 'init_args': [1, 1, 1], + 'args': [], + 'result': tf.constant(1, dtype=tf.float32), + }, + {'testcase_name': 'lipchitz constant', + 'fn': 'lipchitz_constant', + 'init_args': [1, 1, 1], + 'args': [1], + 'result': tf.constant(2, dtype=tf.float32), + }, + {'testcase_name': 'kernel regularizer', + 'fn': 'kernel_regularizer', + 'init_args': [1, 1, 1], + 'args': [], + 'result': L1L2(l2=0.5), + }, + ]) + def test_fns(self, init_args, fn, args, result): + """Test that fn of BinaryCrossentropy loss returns the correct result + Args: + init_args: init values for loss instance + fn: the fn to test + args: the arguments to above function + result: the correct result from the fn + """ + loss = StrongConvexBinaryCrossentropy(*init_args) + expected = getattr(loss, fn, lambda: 'fn not found')(*args) + if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor + expected = expected.numpy() + result = result.numpy() + if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer + expected = expected.l2 + result = result.l2 + self.assertEqual(expected, result) + + +class HuberTests(keras_parameterized.TestCase): + """tests for BinaryCrossesntropy StrongConvex loss""" + + @parameterized.named_parameters([ + {'testcase_name': 'normal', + 'reg_lambda': 1, + 'c': 1, + 'radius_constant': 1, + 'delta': 1, + }, + ]) + def test_init_params(self, reg_lambda, c, radius_constant, delta): + """Test initialization for given arguments + Args: + reg_lambda: initialization value for reg_lambda arg + C: initialization value for C arg + radius_constant: initialization value for radius_constant arg + """ + # test valid domains for each variable + loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta) + self.assertIsInstance(loss, StrongConvexHuber) + + @parameterized.named_parameters([ + {'testcase_name': 'negative c', + 'reg_lambda': 1, + 'c': -1, + 'radius_constant': 1, + 'delta': 1 + }, + {'testcase_name': 'negative radius', + 'reg_lambda': 1, + 'c': 1, + 'radius_constant': -1, + 'delta': 1 + }, + {'testcase_name': 'negative lambda', + 'reg_lambda': -1, + 'c': 1, + 'radius_constant': 1, + 'delta': 1 + }, + {'testcase_name': 'negative delta', + 'reg_lambda': -1, + 'c': 1, + 'radius_constant': 1, + 'delta': -1 + }, + ]) + def test_bad_init_params(self, reg_lambda, c, radius_constant, delta): + """Test invalid domain for given params. Should return ValueError + Args: + reg_lambda: initialization value for reg_lambda arg + C: initialization value for C arg + radius_constant: initialization value for radius_constant arg + """ + # test valid domains for each variable + with self.assertRaises(ValueError): + StrongConvexHuber(reg_lambda, c, radius_constant, delta) + + # test the bounds and test varied delta's + @test_util.run_all_in_graph_and_eager_modes + @parameterized.named_parameters([ + {'testcase_name': 'delta=1,y_true=1 z>1+h decision boundary', + 'logits': 2.1, + 'y_true': 1, + 'delta': 1, + 'result': 0, + }, + {'testcase_name': 'delta=1,y_true=1 z<1+h decision boundary', + 'logits': 1.9, + 'y_true': 1, + 'delta': 1, + 'result': 0.01*0.25, + }, + {'testcase_name': 'delta=1,y_true=1 1-z< h decision boundary', + 'logits': 0.1, + 'y_true': 1, + 'delta': 1, + 'result': 1.9**2 * 0.25, + }, + {'testcase_name': 'delta=1,y_true=1 z < 1-h decision boundary', + 'logits': -0.1, + 'y_true': 1, + 'delta': 1, + 'result': 1.1, + }, + {'testcase_name': 'delta=2,y_true=1 z>1+h decision boundary', + 'logits': 3.1, + 'y_true': 1, + 'delta': 2, + 'result': 0, + }, + {'testcase_name': 'delta=2,y_true=1 z<1+h decision boundary', + 'logits': 2.9, + 'y_true': 1, + 'delta': 2, + 'result': 0.01*0.125, + }, + {'testcase_name': 'delta=2,y_true=1 1-z < h decision boundary', + 'logits': 1.1, + 'y_true': 1, + 'delta': 2, + 'result': 1.9**2 * 0.125, + }, + {'testcase_name': 'delta=2,y_true=1 z < 1-h decision boundary', + 'logits': -1.1, + 'y_true': 1, + 'delta': 2, + 'result': 2.1, + }, + {'testcase_name': 'delta=1,y_true=-1 z>1+h decision boundary', + 'logits': -2.1, + 'y_true': -1, + 'delta': 1, + 'result': 0, + }, + ]) + def test_calculation(self, logits, y_true, delta, result): + """Test the call method to ensure it returns the correct value + Args: + logits: unscaled output of model + y_true: label + result: correct loss calculation value + """ + logits = tf.Variable(logits, False, dtype=tf.float32) + y_true = tf.Variable(y_true, False, dtype=tf.float32) + loss = StrongConvexHuber(0.00001, 1, 1, delta) + loss = loss(y_true, logits) + self.assertAllClose(loss.numpy(), result) + + @parameterized.named_parameters([ + {'testcase_name': 'beta', + 'init_args': [1, 1, 1, 1], + 'fn': 'beta', + 'args': [1], + 'result': tf.Variable(1.5, dtype=tf.float32) + }, + {'testcase_name': 'gamma', + 'fn': 'gamma', + 'init_args': [1, 1, 1, 1], + 'args': [], + 'result': tf.Variable(1, dtype=tf.float32), + }, + {'testcase_name': 'lipchitz constant', + 'fn': 'lipchitz_constant', + 'init_args': [1, 1, 1, 1], + 'args': [1], + 'result': tf.Variable(2, dtype=tf.float32), + }, + {'testcase_name': 'kernel regularizer', + 'fn': 'kernel_regularizer', + 'init_args': [1, 1, 1, 1], + 'args': [], + 'result': L1L2(l2=0.5), + }, + ]) + def test_fns(self, init_args, fn, args, result): + """Test that fn of BinaryCrossentropy loss returns the correct result + Args: + init_args: init values for loss instance + fn: the fn to test + args: the arguments to above function + result: the correct result from the fn + """ + loss = StrongConvexHuber(*init_args) + expected = getattr(loss, fn, lambda: 'fn not found')(*args) + if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor + expected = expected.numpy() + result = result.numpy() + if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer + expected = expected.l2 + result = result.l2 + self.assertEqual(expected, result) + + +if __name__ == '__main__': + tf.test.main() diff --git a/privacy/bolton/models.py b/privacy/bolton/models.py new file mode 100644 index 0000000..7503157 --- /dev/null +++ b/privacy/bolton/models.py @@ -0,0 +1,302 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Bolton model for bolton method of differentially private ML""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tensorflow as tf +from tensorflow.python.keras.models import Model +from tensorflow.python.keras import optimizers +from tensorflow.python.framework import ops as _ops +from privacy.bolton.losses import StrongConvexMixin +from privacy.bolton.optimizers import Bolton + + +class BoltonModel(Model): + """ + Bolton episilon-delta model + Uses 4 key steps to achieve privacy guarantees: + 1. Adds noise to weights after training (output perturbation). + 2. Projects weights to R after each batch + 3. Limits learning rate + 4. Use a strongly convex loss function (see compile) + + For more details on the strong convexity requirements, see: + Bolt-on Differential Privacy for Scalable Stochastic Gradient + Descent-based Analytics by Xi Wu et. al. + """ + + def __init__(self, + n_outputs, + seed=1, + dtype=tf.float32 + ): + """ private constructor. + + Args: + n_outputs: number of output classes to predict. + seed: random seed to use + dtype: data type to use for tensors + """ + super(BoltonModel, self).__init__(name='bolton', dynamic=False) + if n_outputs <= 0: + raise ValueError('n_outputs = {0} is not valid. Must be > 0.'.format( + n_outputs + )) + self.n_outputs = n_outputs + self.seed = seed + self._layers_instantiated = False + self._dtype = dtype + + def call(self, inputs): # pylint: disable=arguments-differ + """Forward pass of network + + Args: + inputs: inputs to neural network + + Returns: + + """ + return self.output_layer(inputs) + + def compile(self, + optimizer, + loss, + metrics=None, + loss_weights=None, + sample_weight_mode=None, + weighted_metrics=None, + target_tensors=None, + distribute=None, + kernel_initializer=tf.initializers.GlorotUniform, + **kwargs): # pylint: disable=arguments-differ + """See super class. Default optimizer used in Bolton method is SGD. + + """ + if not isinstance(loss, StrongConvexMixin): + raise ValueError("loss function must be a Strongly Convex and therefore " + "extend the StrongConvexMixin.") + if not self._layers_instantiated: # compile may be called multiple times + # for instance, if the input/outputs are not defined until fit. + self.output_layer = tf.keras.layers.Dense( + self.n_outputs, + kernel_regularizer=loss.kernel_regularizer(), + kernel_initializer=kernel_initializer(), + ) + self._layers_instantiated = True + if not isinstance(optimizer, Bolton): + optimizer = optimizers.get(optimizer) + optimizer = Bolton(optimizer, loss) + + super(BoltonModel, self).compile(optimizer, + loss=loss, + metrics=metrics, + loss_weights=loss_weights, + sample_weight_mode=sample_weight_mode, + weighted_metrics=weighted_metrics, + target_tensors=target_tensors, + distribute=distribute, + **kwargs + ) + + def fit(self, + x=None, + y=None, + batch_size=None, + class_weight=None, + n_samples=None, + epsilon=2, + noise_distribution='laplace', + steps_per_epoch=None, + **kwargs): # pylint: disable=arguments-differ + """Reroutes to super fit with additional Bolton delta-epsilon privacy + requirements implemented. Note, inputs must be normalized s.t. ||x|| < 1 + Requirements are as follows: + 1. Adds noise to weights after training (output perturbation). + 2. Projects weights to R after each batch + 3. Limits learning rate + 4. Use a strongly convex loss function (see compile) + See super implementation for more details. + + Args: + n_samples: the number of individual samples in x. + epsilon: privacy parameter, which trades off between utility an privacy. + See the bolton paper for more description. + noise_distribution: the distribution to pull noise from. + class_weight: the class weights to be used. Can be a scalar or 1D tensor + whose dim == n_classes. + + See the super method for descriptions on the rest of the arguments. + + """ + if class_weight is None: + class_weight_ = self.calculate_class_weights(class_weight) + else: + class_weight_ = class_weight + if n_samples is not None: + data_size = n_samples + elif hasattr(x, 'shape'): + data_size = x.shape[0] + elif hasattr(x, "__len__"): + data_size = len(x) + else: + data_size = None + batch_size_ = self._validate_or_infer_batch_size(batch_size, + steps_per_epoch, + x + ) + # inferring batch_size to be passed to optimizer. batch_size must remain its + # initial value when passed to super().fit() + if batch_size_ is None: + raise ValueError('batch_size: {0} is an ' + 'invalid value'.format(batch_size_)) + if data_size is None: + raise ValueError('Could not infer the number of samples. Please pass ' + 'this in using n_samples.') + with self.optimizer(noise_distribution, + epsilon, + self.layers, + class_weight_, + data_size, + batch_size_, + ) as _: + out = super(BoltonModel, self).fit(x=x, + y=y, + batch_size=batch_size, + class_weight=class_weight, + steps_per_epoch=steps_per_epoch, + **kwargs + ) + return out + + def fit_generator(self, + generator, + class_weight=None, + noise_distribution='laplace', + epsilon=2, + n_samples=None, + steps_per_epoch=None, + **kwargs + ): # pylint: disable=arguments-differ + """ + This method is the same as fit except for when the passed dataset + is a generator. See super method and fit for more details. + Args: + n_samples: number of individual samples in x + noise_distribution: the distribution to get noise from. + epsilon: privacy parameter, which trades off utility and privacy. See + Bolton paper for more description. + class_weight: the class weights to be used. Can be a scalar or 1D tensor + whose dim == n_classes. + + See the super method for descriptions on the rest of the arguments. + """ + if class_weight is None: + class_weight = self.calculate_class_weights(class_weight) + if n_samples is not None: + data_size = n_samples + elif hasattr(generator, 'shape'): + data_size = generator.shape[0] + elif hasattr(generator, "__len__"): + data_size = len(generator) + else: + data_size = None + batch_size = self._validate_or_infer_batch_size(None, + steps_per_epoch, + generator + ) + with self.optimizer(noise_distribution, + epsilon, + self.layers, + class_weight, + data_size, + batch_size + ) as _: + out = super(BoltonModel, self).fit_generator( + generator, + class_weight=class_weight, + steps_per_epoch=steps_per_epoch, + **kwargs + ) + return out + + def calculate_class_weights(self, + class_weights=None, + class_counts=None, + num_classes=None + ): + """ + Calculates class weighting to be used in training. Can be on + Args: + class_weights: str specifying type, array giving weights, or None. + class_counts: If class_weights is not None, then an array of + the number of samples for each class + num_classes: If class_weights is not None, then the number of + classes. + Returns: class_weights as 1D tensor, to be passed to model's fit method. + + """ + # Value checking + class_keys = ['balanced'] + is_string = False + if isinstance(class_weights, str): + is_string = True + if class_weights not in class_keys: + raise ValueError("Detected string class_weights with " + "value: {0}, which is not one of {1}." + "Please select a valid class_weight type" + "or pass an array".format(class_weights, + class_keys)) + if class_counts is None: + raise ValueError("Class counts must be provided if using " + "class_weights=%s" % class_weights) + class_counts_shape = tf.Variable(class_counts, + trainable=False, + dtype=self._dtype).shape + if len(class_counts_shape) != 1: + raise ValueError('class counts must be a 1D array.' + 'Detected: {0}'.format(class_counts_shape)) + if num_classes is None: + raise ValueError("num_classes must be provided if using " + "class_weights=%s" % class_weights) + elif class_weights is not None: + if num_classes is None: + raise ValueError("You must pass a value for num_classes if " + "creating an array of class_weights") + # performing class weight calculation + if class_weights is None: + class_weights = 1 + elif is_string and class_weights == 'balanced': + num_samples = sum(class_counts) + weighted_counts = tf.dtypes.cast(tf.math.multiply(num_classes, + class_counts, + ), + self._dtype + ) + class_weights = tf.Variable(num_samples, dtype=self._dtype) / \ + tf.Variable(weighted_counts, dtype=self._dtype) + else: + class_weights = _ops.convert_to_tensor_v2(class_weights) + if len(class_weights.shape) != 1: + raise ValueError("Detected class_weights shape: {0} instead of " + "1D array".format(class_weights.shape)) + if class_weights.shape[0] != num_classes: + raise ValueError( + "Detected array length: {0} instead of: {1}".format( + class_weights.shape[0], + num_classes + ) + ) + return class_weights diff --git a/privacy/bolton/models_test.py b/privacy/bolton/models_test.py new file mode 100644 index 0000000..63954cc --- /dev/null +++ b/privacy/bolton/models_test.py @@ -0,0 +1,520 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit testing for models.py""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import tensorflow as tf +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2 +from tensorflow.python.keras import losses +from tensorflow.python.framework import ops as _ops +from tensorflow.python.keras.regularizers import L1L2 +from absl.testing import parameterized +from privacy.bolton import models +from privacy.bolton.optimizers import Bolton +from privacy.bolton.losses import StrongConvexMixin + +class TestLoss(losses.Loss, StrongConvexMixin): + """Test loss function for testing Bolton model""" + def __init__(self, reg_lambda, C, radius_constant, name='test'): + super(TestLoss, self).__init__(name=name) + self.reg_lambda = reg_lambda + self.C = C # pylint: disable=invalid-name + self.radius_constant = radius_constant + + def radius(self): + """Radius of R-Ball (value to normalize weights to after each batch) + + Returns: radius + + """ + return _ops.convert_to_tensor_v2(1, dtype=tf.float32) + + def gamma(self): + """ Gamma strongly convex + + Returns: gamma + + """ + return _ops.convert_to_tensor_v2(1, dtype=tf.float32) + + def beta(self, class_weight): # pylint: disable=unused-argument + """Beta smoothess + + Args: + class_weight: the class weights used. + + Returns: Beta + + """ + return _ops.convert_to_tensor_v2(1, dtype=tf.float32) + + def lipchitz_constant(self, class_weight): # pylint: disable=unused-argument + """ L lipchitz continuous + + Args: + class_weight: class weights used + + Returns: L + + """ + return _ops.convert_to_tensor_v2(1, dtype=tf.float32) + + def call(self, y_true, y_pred): + """Loss function that is minimized at the mean of the input points.""" + return 0.5 * tf.reduce_sum( + tf.math.squared_difference(y_true, y_pred), + axis=1 + ) + + def max_class_weight(self, class_weight): + if class_weight is None: + return 1 + raise ValueError('') + + def kernel_regularizer(self): + return L1L2(l2=self.reg_lambda) + + +class TestOptimizer(OptimizerV2): + """Test optimizer used for testing Bolton model""" + def __init__(self): + super(TestOptimizer, self).__init__('test') + + def compute_gradients(self): + return 0 + + def get_config(self): + return {} + + def _create_slots(self, var): + pass + + def _resource_apply_dense(self, grad, handle): + return grad + + def _resource_apply_sparse(self, grad, handle, indices): + return grad + + +class InitTests(keras_parameterized.TestCase): + """tests for keras model initialization""" + + @parameterized.named_parameters([ + {'testcase_name': 'normal', + 'n_outputs': 1, + }, + {'testcase_name': 'many outputs', + 'n_outputs': 100, + }, + ]) + def test_init_params(self, n_outputs): + """test initialization of BoltonModel + + Args: + n_outputs: number of output neurons + """ + # test valid domains for each variable + clf = models.BoltonModel(n_outputs) + self.assertIsInstance(clf, models.BoltonModel) + + @parameterized.named_parameters([ + {'testcase_name': 'invalid n_outputs', + 'n_outputs': -1, + }, + ]) + def test_bad_init_params(self, n_outputs): + """test bad initializations of BoltonModel that should raise errors + + Args: + n_outputs: number of output neurons + """ + # test invalid domains for each variable, especially noise + with self.assertRaises(ValueError): + models.BoltonModel(n_outputs) + + @parameterized.named_parameters([ + {'testcase_name': 'string compile', + 'n_outputs': 1, + 'loss': TestLoss(1, 1, 1), + 'optimizer': 'adam', + }, + {'testcase_name': 'test compile', + 'n_outputs': 100, + 'loss': TestLoss(1, 1, 1), + 'optimizer': TestOptimizer(), + }, + ]) + def test_compile(self, n_outputs, loss, optimizer): + """test compilation of BoltonModel + + Args: + n_outputs: number of output neurons + loss: instantiated TestLoss instance + optimizer: instanced TestOptimizer instance + """ + # test compilation of valid tf.optimizer and tf.loss + with self.cached_session(): + clf = models.BoltonModel(n_outputs) + clf.compile(optimizer, loss) + self.assertEqual(clf.loss, loss) + + @parameterized.named_parameters([ + {'testcase_name': 'Not strong loss', + 'n_outputs': 1, + 'loss': losses.BinaryCrossentropy(), + 'optimizer': 'adam', + }, + {'testcase_name': 'Not valid optimizer', + 'n_outputs': 1, + 'loss': TestLoss(1, 1, 1), + 'optimizer': 'ada', + } + ]) + def test_bad_compile(self, n_outputs, loss, optimizer): + """test bad compilations of BoltonModel that should raise errors + + Args: + n_outputs: number of output neurons + loss: instantiated TestLoss instance + optimizer: instanced TestOptimizer instance + """ + # test compilaton of invalid tf.optimizer and non instantiated loss. + with self.cached_session(): + with self.assertRaises((ValueError, AttributeError)): + clf = models.BoltonModel(n_outputs) + clf.compile(optimizer, loss) + + +def _cat_dataset(n_samples, input_dim, n_classes, generator=False): + """ + Creates a categorically encoded dataset (y is categorical). + returns the specified dataset either as a static array or as a generator. + Will have evenly split samples across each output class. + Each output class will be a different point in the input space. + + Args: + n_samples: number of rows + input_dim: input dimensionality + n_classes: output dimensionality + generator: False for array, True for generator + Returns: + X as (n_samples, input_dim), Y as (n_samples, n_outputs) + """ + x_stack = [] + y_stack = [] + for i_class in range(n_classes): + x_stack.append( + tf.constant(1*i_class, tf.float32, (n_samples, input_dim)) + ) + y_stack.append( + tf.constant(i_class, tf.float32, (n_samples, n_classes)) + ) + x_set, y_set = tf.stack(x_stack), tf.stack(y_stack) + if generator: + dataset = tf.data.Dataset.from_tensor_slices( + (x_set, y_set) + ) + return dataset + return x_set, y_set + +def _do_fit(n_samples, + input_dim, + n_outputs, + epsilon, + generator, + batch_size, + reset_n_samples, + optimizer, + loss, + distribution='laplace'): + """Helper to instantiate necessary components for fitting and perform a model + fit. + + Args: + n_samples: number of samples in dataset + input_dim: the sample dimensionality + n_outputs: number of output neurons + epsilon: privacy parameter + generator: True to create a generator, False to use an iterator + batch_size: batch_size to use + reset_n_samples: True to set _samples to None prior to fitting. + False does nothing + optimizer: instance of TestOptimizer + loss: instance of TestLoss + distribution: distribution to get noise from. + + Returns: BoltonModel instsance + """ + clf = models.BoltonModel(n_outputs) + clf.compile(optimizer, loss) + if generator: + x = _cat_dataset( + n_samples, + input_dim, + n_outputs, + generator=generator + ) + y = None + # x = x.batch(batch_size) + x = x.shuffle(n_samples//2) + batch_size = None + else: + x, y = _cat_dataset(n_samples, input_dim, n_outputs, generator=generator) + if reset_n_samples: + n_samples = None + + clf.fit(x, + y, + batch_size=batch_size, + n_samples=n_samples, + noise_distribution=distribution, + epsilon=epsilon + ) + return clf + + +class FitTests(keras_parameterized.TestCase): + """Test cases for keras model fitting""" + + # @test_util.run_all_in_graph_and_eager_modes + @parameterized.named_parameters([ + {'testcase_name': 'iterator fit', + 'generator': False, + 'reset_n_samples': True, + }, + {'testcase_name': 'iterator fit no samples', + 'generator': False, + 'reset_n_samples': True, + }, + {'testcase_name': 'generator fit', + 'generator': True, + 'reset_n_samples': False, + }, + {'testcase_name': 'with callbacks', + 'generator': True, + 'reset_n_samples': False, + }, + ]) + def test_fit(self, generator, reset_n_samples): + """Tests fitting of BoltonModel + + Args: + generator: True for generator test, False for iterator test. + reset_n_samples: True to reset the n_samples to None, False does nothing + """ + loss = TestLoss(1, 1, 1) + optimizer = Bolton(TestOptimizer(), loss) + n_classes = 2 + input_dim = 5 + epsilon = 1 + batch_size = 1 + n_samples = 10 + clf = _do_fit( + n_samples, + input_dim, + n_classes, + epsilon, + generator, + batch_size, + reset_n_samples, + optimizer, + loss, + ) + self.assertEqual(hasattr(clf, 'layers'), True) + + @parameterized.named_parameters([ + {'testcase_name': 'generator fit', + 'generator': True, + }, + ]) + def test_fit_gen(self, generator): + """Tests the fit_generator method of BoltonModel + + Args: + generator: True to test with a generator dataset + """ + loss = TestLoss(1, 1, 1) + optimizer = TestOptimizer() + n_classes = 2 + input_dim = 5 + batch_size = 1 + n_samples = 10 + clf = models.BoltonModel(n_classes) + clf.compile(optimizer, loss) + x = _cat_dataset( + n_samples, + input_dim, + n_classes, + generator=generator + ) + x = x.batch(batch_size) + x = x.shuffle(n_samples // 2) + clf.fit_generator(x, n_samples=n_samples) + self.assertEqual(hasattr(clf, 'layers'), True) + + @parameterized.named_parameters([ + {'testcase_name': 'iterator no n_samples', + 'generator': True, + 'reset_n_samples': True, + 'distribution': 'laplace' + }, + {'testcase_name': 'invalid distribution', + 'generator': True, + 'reset_n_samples': True, + 'distribution': 'not_valid' + }, + ]) + def test_bad_fit(self, generator, reset_n_samples, distribution): + """Tests fitting with invalid parameters, which should raise an error + + Args: + generator: True to test with generator, False is iterator + reset_n_samples: True to reset the n_samples param to None prior to + passing it to fit + distribution: distribution to get noise from. + """ + with self.assertRaises(ValueError): + loss = TestLoss(1, 1, 1) + optimizer = TestOptimizer() + n_classes = 2 + input_dim = 5 + epsilon = 1 + batch_size = 1 + n_samples = 10 + _do_fit( + n_samples, + input_dim, + n_classes, + epsilon, + generator, + batch_size, + reset_n_samples, + optimizer, + loss, + distribution + ) + + @parameterized.named_parameters([ + {'testcase_name': 'None class_weights', + 'class_weights': None, + 'class_counts': None, + 'num_classes': None, + 'result': 1}, + {'testcase_name': 'class weights array', + 'class_weights': [1, 1], + 'class_counts': [1, 1], + 'num_classes': 2, + 'result': [1, 1]}, + {'testcase_name': 'class weights balanced', + 'class_weights': 'balanced', + 'class_counts': [1, 1], + 'num_classes': 2, + 'result': [1, 1]}, + ]) + def test_class_calculate(self, + class_weights, + class_counts, + num_classes, + result + ): + """Tests the BOltonModel calculate_class_weights method + + Args: + class_weights: the class_weights to use + class_counts: count of number of samples for each class + num_classes: number of outputs neurons + result: expected result + """ + clf = models.BoltonModel(1, 1) + expected = clf.calculate_class_weights(class_weights, + class_counts, + num_classes + ) + + if hasattr(expected, 'numpy'): + expected = expected.numpy() + self.assertAllEqual( + expected, + result + ) + @parameterized.named_parameters([ + {'testcase_name': 'class weight not valid str', + 'class_weights': 'not_valid', + 'class_counts': 1, + 'num_classes': 1, + 'err_msg': "Detected string class_weights with value: not_valid"}, + {'testcase_name': 'no class counts', + 'class_weights': 'balanced', + 'class_counts': None, + 'num_classes': 1, + 'err_msg': "Class counts must be provided if " + "using class_weights=balanced"}, + {'testcase_name': 'no num classes', + 'class_weights': 'balanced', + 'class_counts': [1], + 'num_classes': None, + 'err_msg': 'num_classes must be provided if ' + 'using class_weights=balanced'}, + {'testcase_name': 'class counts not array', + 'class_weights': 'balanced', + 'class_counts': 1, + 'num_classes': None, + 'err_msg': 'class counts must be a 1D array.'}, + {'testcase_name': 'class counts array, no num classes', + 'class_weights': [1], + 'class_counts': None, + 'num_classes': None, + 'err_msg': "You must pass a value for num_classes if " + "creating an array of class_weights"}, + {'testcase_name': 'class counts array, improper shape', + 'class_weights': [[1], [1]], + 'class_counts': None, + 'num_classes': 2, + 'err_msg': "Detected class_weights shape"}, + {'testcase_name': 'class counts array, wrong number classes', + 'class_weights': [1, 1, 1], + 'class_counts': None, + 'num_classes': 2, + 'err_msg': "Detected array length:"}, + ]) + def test_class_errors(self, + class_weights, + class_counts, + num_classes, + err_msg + ): + """Tests the BOltonModel calculate_class_weights method with invalid params + which should raise the expected errors. + + Args: + class_weights: the class_weights to use + class_counts: count of number of samples for each class + num_classes: number of outputs neurons + result: expected result + """ + clf = models.BoltonModel(1, 1) + with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method + clf.calculate_class_weights(class_weights, + class_counts, + num_classes + ) + + +if __name__ == '__main__': + tf.test.main() diff --git a/privacy/bolton/optimizers.py b/privacy/bolton/optimizers.py new file mode 100644 index 0000000..ec7a7e5 --- /dev/null +++ b/privacy/bolton/optimizers.py @@ -0,0 +1,391 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Bolton Optimizer for bolton method""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.python.keras.optimizer_v2 import optimizer_v2 +from tensorflow.python.ops import math_ops +from privacy.bolton.losses import StrongConvexMixin + +_accepted_distributions = ['laplace'] # implemented distributions for noising + + +class GammaBetaDecreasingStep( + optimizer_v2.learning_rate_schedule.LearningRateSchedule +): + """ + Learning Rate Scheduler using the minimum of 1/beta and 1/(gamma * step) + at each step. A required step for privacy guarantees. + """ + def __init__(self): + self.is_init = False + self.beta = None + self.gamma = None + + def __call__(self, step): + """ + returns the learning rate + Args: + step: the current iteration number + Returns: + decayed learning rate to minimum of 1/beta and 1/(gamma * step) as per + the Bolton privacy requirements. + """ + if not self.is_init: + raise AttributeError('Please initialize the {0} Learning Rate Scheduler.' + 'This is performed automatically by using the ' + '{1} as a context manager, ' + 'as desired'.format(self.__class__.__name__, + Bolton.__class__.__name__ + ) + ) + dtype = self.beta.dtype + one = tf.constant(1, dtype) + return tf.math.minimum(tf.math.reduce_min(one/self.beta), + one/(self.gamma*math_ops.cast(step, dtype)) + ) + + def get_config(self): + """ + config to setup the learning rate scheduler. + """ + return {'beta': self.beta, 'gamma': self.gamma} + + def initialize(self, beta, gamma): + """setup the learning rate scheduler with the beta and gamma values provided + by the loss function. Meant to be used with .fit as the loss params may + depend on values passed to fit. + + Args: + beta: Smoothness value. See StrongConvexMixin + gamma: Strong Convexity parameter. See StrongConvexMixin. + """ + self.is_init = True + self.beta = beta + self.gamma = gamma + + def de_initialize(self): + """De initialize the scheduler after fitting, in case another fit call has + different loss parameters. + """ + self.is_init = False + self.beta = None + self.gamma = None + + +class Bolton(optimizer_v2.OptimizerV2): + """ + Bolton optimizer wraps another tf optimizer to be used + as the visible optimizer to the tf model. No matter the optimizer + passed, "Bolton" enables the bolton model to control the learning rate + based on the strongly convex loss. + + To use the Bolton method, you must: + 1. instantiate it with an instantiated tf optimizer and StrongConvexLoss. + 2. use it as a context manager around your .fit method internals. + + This can be accomplished by the following: + optimizer = tf.optimizers.SGD() + loss = privacy.bolton.losses.StrongConvexBinaryCrossentropy() + bolton = Bolton(optimizer, loss) + with bolton(*args) as _: + model.fit() + The args required for the context manager can be found in the __call__ + method. + + For more details on the strong convexity requirements, see: + Bolt-on Differential Privacy for Scalable Stochastic Gradient + Descent-based Analytics by Xi Wu et. al. + """ + def __init__(self, # pylint: disable=super-init-not-called + optimizer: optimizer_v2.OptimizerV2, + loss: StrongConvexMixin, + dtype=tf.float32, + ): + """Constructor. + + Args: + optimizer: Optimizer_v2 or subclass to be used as the optimizer + (wrapped). + loss: StrongConvexLoss function that the model is being compiled with. + """ + + if not isinstance(loss, StrongConvexMixin): + raise ValueError("loss function must be a Strongly Convex and therefore " + "extend the StrongConvexMixin.") + self._private_attributes = ['_internal_optimizer', + 'dtype', + 'noise_distribution', + 'epsilon', + 'loss', + 'class_weights', + 'input_dim', + 'n_samples', + 'layers', + 'batch_size', + '_is_init' + ] + self._internal_optimizer = optimizer + self.learning_rate = GammaBetaDecreasingStep() # use the Bolton Learning + # rate scheduler, as required for privacy guarantees. This will still need + # to get values from the loss function near the time that .fit is called + # on the model (when this optimizer will be called as a context manager) + self.dtype = dtype + self.loss = loss + self._is_init = False + + def get_config(self): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer.get_config() + + def project_weights_to_r(self, force=False): + """helper method to normalize the weights to the R-ball. + + Args: + force: True to normalize regardless of previous weight values. + False to check if weights > R-ball and only normalize then. + + Returns: + + """ + if not self._is_init: + raise Exception('This method must be called from within the optimizer\'s ' + 'context.') + radius = self.loss.radius() + for layer in self.layers: + weight_norm = tf.norm(layer.kernel, axis=0) + if force: + layer.kernel = layer.kernel / (weight_norm / radius) + else: + layer.kernel = tf.cond( + tf.reduce_sum(tf.cast(weight_norm > radius, dtype=self.dtype)) > 0, + lambda k=layer.kernel, w=weight_norm, r=radius: k / (w / r), # pylint: disable=cell-var-from-loop + lambda k=layer.kernel: k # pylint: disable=cell-var-from-loop + ) + + def get_noise(self, input_dim, output_dim): + """Sample noise to be added to weights for privacy guarantee + + Args: + input_dim: the input dimensionality for the weights + output_dim the output dimensionality for the weights + + Returns: noise in shape of layer's weights to be added to the weights. + + """ + if not self._is_init: + raise Exception('This method must be called from within the optimizer\'s ' + 'context.') + loss = self.loss + distribution = self.noise_distribution.lower() + if distribution == _accepted_distributions[0]: # laplace + per_class_epsilon = self.epsilon / (output_dim) + l2_sensitivity = (2 * + loss.lipchitz_constant(self.class_weights)) / \ + (loss.gamma() * self.n_samples * self.batch_size) + unit_vector = tf.random.normal(shape=(input_dim, output_dim), + mean=0, + seed=1, + stddev=1.0, + dtype=self.dtype) + unit_vector = unit_vector / tf.math.sqrt( + tf.reduce_sum(tf.math.square(unit_vector), axis=0) + ) + + beta = l2_sensitivity / per_class_epsilon + alpha = input_dim # input_dim + gamma = tf.random.gamma([output_dim], + alpha, + beta=1 / beta, + seed=1, + dtype=self.dtype + ) + return unit_vector * gamma + raise NotImplementedError('Noise distribution: {0} is not ' + 'a valid distribution'.format(distribution)) + + def from_config(self, *args, **kwargs): # pylint: disable=arguments-differ + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer.from_config(*args, **kwargs) + + def __getattr__(self, name): + """return _internal_optimizer off self instance, and everything else + from the _internal_optimizer instance. + + Args: + name: + + Returns: attribute from Bolton if specified to come from self, else + from _internal_optimizer. + + """ + if name == '_private_attributes' or name in self._private_attributes: + return getattr(self, name) + optim = object.__getattribute__(self, '_internal_optimizer') + try: + return object.__getattribute__(optim, name) + except AttributeError: + raise AttributeError( + "Neither '{0}' nor '{1}' object has attribute '{2}'" + "".format(self.__class__.__name__, + self._internal_optimizer.__class__.__name__, + name + ) + ) + + def __setattr__(self, key, value): + """ Set attribute to self instance if its the internal optimizer. + Reroute everything else to the _internal_optimizer. + + Args: + key: attribute name + value: attribute value + + Returns: + + """ + if key == '_private_attributes': + object.__setattr__(self, key, value) + elif key in self._private_attributes: + object.__setattr__(self, key, value) + else: + setattr(self._internal_optimizer, key, value) + + def _resource_apply_dense(self, *args, **kwargs): # pylint: disable=arguments-differ + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer._resource_apply_dense(*args, **kwargs) # pylint: disable=protected-access + + def _resource_apply_sparse(self, *args, **kwargs): # pylint: disable=arguments-differ + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer._resource_apply_sparse(*args, **kwargs) # pylint: disable=protected-access + + def get_updates(self, loss, params): + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + out = self._internal_optimizer.get_updates(loss, params) + self.project_weights_to_r() + return out + + def apply_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + out = self._internal_optimizer.apply_gradients(*args, **kwargs) + self.project_weights_to_r() + return out + + def minimize(self, *args, **kwargs): # pylint: disable=arguments-differ + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + out = self._internal_optimizer.minimize(*args, **kwargs) + self.project_weights_to_r() + return out + + def _compute_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ,protected-access + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer._compute_gradients(*args, **kwargs) # pylint: disable=protected-access + + def get_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ + """Reroutes to _internal_optimizer. See super/_internal_optimizer. + """ + return self._internal_optimizer.get_gradients(*args, **kwargs) + + def __enter__(self): + """Context manager call at the beginning of with statement. + + Returns: + self, to be used in context manager + """ + self._is_init = True + return self + + def __call__(self, + noise_distribution: str, + epsilon: float, + layers: list, + class_weights, + n_samples, + batch_size + ): + """Entry point from context. Accepts required values for bolton method and + stores them on the optimizer for use throughout fitting. + + Args: + noise_distribution: the noise distribution to pick. + see _accepted_distributions and get_noise for + possible values. + epsilon: privacy parameter. Lower gives more privacy but less utility. + layers: list of Keras/Tensorflow layers. Can be found as model.layers + class_weights: class_weights used, which may either be a scalar or 1D + tensor with dim == n_classes. + n_samples number of rows/individual samples in the training set + batch_size: batch size used. + """ + if epsilon <= 0: + raise ValueError('Detected epsilon: {0}. ' + 'Valid range is 0 < epsilon .wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "Use tf.where in 2.0, which has the same broadcast rule as np.where\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 20 samples\n", + "Epoch 1/2\n", + "20/20 [==============================] - 0s 4ms/sample - loss: 0.8146\n", + "Epoch 2/2\n", + "20/20 [==============================] - 0s 94us/sample - loss: 0.5699\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# required parameters\n", + "class_weight = None # default, use .calculate_class_weight to specify other values\n", + "batch_size = None # default, if it cannot be inferred, specify this\n", + "n_samples = None # default, if it cannot be iferred, specify this\n", + "# privacy parameters\n", + "epsilon = 2\n", + "noise_distribution = 'laplace'\n", + "\n", + "bolt.fit(x, \n", + " y, \n", + " epsilon=epsilon, \n", + " class_weight=class_weight, \n", + " batch_size=batch_size, \n", + " n_samples=n_samples,\n", + " noise_distribution=noise_distribution,\n", + " epochs=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We may also train a generator object, or try different optimizers and loss functions. Below, we will see that we must pass the number of samples as the fit method is unable to infer it for a generator." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "optimizer2 = tf.optimizers.Adam()\n", + "bolt.compile(optimizer2, loss)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Could not infer the number of samples. Please pass this in using n_samples.\n" + ] + } + ], + "source": [ + "# required parameters\n", + "class_weight = None # default, use .calculate_class_weight to specify other values\n", + "batch_size = None # default, if it cannot be inferred, specify this\n", + "n_samples = None # default, if it cannot be iferred, specify this\n", + "# privacy parameters\n", + "epsilon = 2\n", + "noise_distribution = 'laplace'\n", + "try:\n", + " bolt.fit(generator,\n", + " epsilon=epsilon, \n", + " class_weight=class_weight, \n", + " batch_size=batch_size, \n", + " n_samples=n_samples,\n", + " noise_distribution=noise_distribution,\n", + " verbose=0\n", + " )\n", + "except ValueError as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And now, re running with the parameter set." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "n_samples = 20\n", + "bolt.fit(generator,\n", + " epsilon=epsilon, \n", + " class_weight=class_weight, \n", + " batch_size=batch_size, \n", + " n_samples=n_samples,\n", + " noise_distribution=noise_distribution,\n", + " verbose=0\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You don't have to use the bolton model to use the Bolton method. There are only a few requirements:\n", + "1. make sure any requirements from the loss are implemented in the model.\n", + "2. instantiate the optimizer and use it as a context around your fit operation." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from privacy.bolton.optimizers import Bolton" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here, we create our own model and setup the Bolton optimizer." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "class TestModel(tf.keras.Model):\n", + " def __init__(self, reg_layer, n_outputs=1):\n", + " super(TestModel, self).__init__(name='test')\n", + " self.output_layer = tf.keras.layers.Dense(n_outputs,\n", + " kernel_regularizer=reg_layer\n", + " )\n", + " \n", + " def call(self, inputs):\n", + " return self.output_layer(inputs)\n", + "\n", + "optimizer = tf.optimizers.SGD()\n", + "loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)\n", + "optimizer = Bolton(optimizer, loss)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, we instantiate our model and check for 1. Since our loss requires L2 regularization over the kernel, we will pass it to the model." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "n_outputs = 1 # parameter for model and optimizer context.\n", + "test_model = TestModel(loss.kernel_regularizer(), n_outputs)\n", + "test_model.compile(optimizer, loss)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We comply with 2., and use the Bolton Optimizer as a context around the fit method." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 20 samples\n", + "Epoch 1/2\n", + "20/20 [==============================] - 0s 3ms/sample - loss: 0.9096\n", + "Epoch 2/2\n", + "20/20 [==============================] - 0s 430us/sample - loss: 0.5275\n" + ] + } + ], + "source": [ + "# parameters for context\n", + "noise_distribution = 'laplace'\n", + "epsilon = 2\n", + "class_weights = 1 # Previosuly, the fit method auto-detected the class_weights.\n", + "# Here, we need to pass the class_weights explicitly. 1 is the equivalent of None.\n", + "n_samples = 20\n", + "batch_size = 5\n", + "\n", + "with optimizer(\n", + " noise_distribution=noise_distribution,\n", + " epsilon=epsilon,\n", + " layers=test_model.layers,\n", + " class_weights=class_weights, \n", + " n_samples=n_samples,\n", + " batch_size=batch_size\n", + ") as _:\n", + " test_model.fit(x, y, batch_size=batch_size, epochs=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "source": [], + "metadata": { + "collapsed": false + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} \ No newline at end of file