forked from 626_privacy/tensorflow_privacy
Merge pull request #53 from georgianpartners:master
PiperOrigin-RevId: 260990063
This commit is contained in:
commit
9fe5e91de4
10 changed files with 2815 additions and 0 deletions
|
@ -41,3 +41,9 @@ else:
|
||||||
from privacy.optimizers.dp_optimizer import DPAdamOptimizer
|
from privacy.optimizers.dp_optimizer import DPAdamOptimizer
|
||||||
from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
|
from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
|
||||||
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
|
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
|
||||||
|
|
||||||
|
from privacy.bolt_on.models import BoltOnModel
|
||||||
|
from privacy.bolt_on.optimizers import BoltOn
|
||||||
|
from privacy.bolt_on.losses import StrongConvexMixin
|
||||||
|
from privacy.bolt_on.losses import StrongConvexBinaryCrossentropy
|
||||||
|
from privacy.bolt_on.losses import StrongConvexHuber
|
||||||
|
|
57
privacy/bolt_on/README.md
Normal file
57
privacy/bolt_on/README.md
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
# BoltOn Subpackage
|
||||||
|
|
||||||
|
This package contains source code for the BoltOn method, a particular
|
||||||
|
differential-privacy (DP) technique that uses output perturbations and
|
||||||
|
leverages additional assumptions to provide a new way of approaching the
|
||||||
|
privacy guarantees.
|
||||||
|
|
||||||
|
## BoltOn Description
|
||||||
|
|
||||||
|
This method uses 4 key steps to achieve privacy guarantees:
|
||||||
|
1. Adds noise to weights after training (output perturbation).
|
||||||
|
2. Projects weights to R, the radius of the hypothesis space,
|
||||||
|
after each batch. This value is configurable by the user.
|
||||||
|
3. Limits learning rate
|
||||||
|
4. Uses a strongly convex loss function (see compile)
|
||||||
|
|
||||||
|
For more details on the strong convexity requirements, see:
|
||||||
|
Bolt-on Differential Privacy for Scalable Stochastic Gradient
|
||||||
|
Descent-based Analytics by Xi Wu et al. at https://arxiv.org/pdf/1606.04722.pdf
|
||||||
|
|
||||||
|
## Why BoltOn?
|
||||||
|
|
||||||
|
The major difference for the BoltOn method is that it injects noise post model
|
||||||
|
convergence, rather than noising gradients or weights during training. This
|
||||||
|
approach requires some additional constraints listed in the Description.
|
||||||
|
Should the use-case and model satisfy these constraints, this is another
|
||||||
|
approach that can be trained to maximize utility while maintaining the privacy.
|
||||||
|
The paper describes in detail the advantages and disadvantages of this approach
|
||||||
|
and its results compared to some other methods, namely noising at each iteration
|
||||||
|
and no noising.
|
||||||
|
|
||||||
|
## Tutorials
|
||||||
|
|
||||||
|
This package has a tutorial that can be found in the root tutorials directory,
|
||||||
|
under `bolton_tutorial.py`.
|
||||||
|
|
||||||
|
## Contribution
|
||||||
|
|
||||||
|
This package was initially contributed by Georgian Partners with the hope of
|
||||||
|
growing the tensorflow/privacy library. There are several rich use cases for
|
||||||
|
delta-epsilon privacy in machine learning, some of which can be explored here:
|
||||||
|
https://medium.com/apache-mxnet/epsilon-differential-privacy-for-machine-learning-using-mxnet-a4270fe3865e
|
||||||
|
https://arxiv.org/pdf/1811.04911.pdf
|
||||||
|
|
||||||
|
## Contacts
|
||||||
|
|
||||||
|
In addition to the maintainers of tensorflow/privacy listed in the root
|
||||||
|
README.md, please feel free to contact members of Georgian Partners. In
|
||||||
|
particular,
|
||||||
|
|
||||||
|
* Georgian Partners(@georgianpartners)
|
||||||
|
* Ji Chao Zhang(@Jichaogp)
|
||||||
|
* Christopher Choquette(@cchoquette)
|
||||||
|
|
||||||
|
## Copyright
|
||||||
|
|
||||||
|
Copyright 2019 - Google LLC
|
29
privacy/bolt_on/__init__.py
Normal file
29
privacy/bolt_on/__init__.py
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
# Copyright 2019, The TensorFlow Privacy Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""BoltOn Method for privacy."""
|
||||||
|
import sys
|
||||||
|
from distutils.version import LooseVersion
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
if LooseVersion(tf.__version__) < LooseVersion("2.0.0"):
|
||||||
|
raise ImportError("Please upgrade your version "
|
||||||
|
"of tensorflow from: {0} to at least 2.0.0 to "
|
||||||
|
"use privacy/bolt_on".format(LooseVersion(tf.__version__)))
|
||||||
|
if hasattr(sys, "skip_tf_privacy_import"): # Useful for standalone scripts.
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
from privacy.bolt_on.models import BoltOnModel # pylint: disable=g-import-not-at-top
|
||||||
|
from privacy.bolt_on.optimizers import BoltOn # pylint: disable=g-import-not-at-top
|
||||||
|
from privacy.bolt_on.losses import StrongConvexHuber # pylint: disable=g-import-not-at-top
|
||||||
|
from privacy.bolt_on.losses import StrongConvexBinaryCrossentropy # pylint: disable=g-import-not-at-top
|
304
privacy/bolt_on/losses.py
Normal file
304
privacy/bolt_on/losses.py
Normal file
|
@ -0,0 +1,304 @@
|
||||||
|
# Copyright 2019, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Loss functions for BoltOn method."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.framework import ops as _ops
|
||||||
|
from tensorflow.python.keras import losses
|
||||||
|
from tensorflow.python.keras.regularizers import L1L2
|
||||||
|
from tensorflow.python.keras.utils import losses_utils
|
||||||
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
|
||||||
|
|
||||||
|
class StrongConvexMixin: # pylint: disable=old-style-class
|
||||||
|
"""Strong Convex Mixin base class.
|
||||||
|
|
||||||
|
Strong Convex Mixin base class for any loss function that will be used with
|
||||||
|
BoltOn model. Subclasses must be strongly convex and implement the
|
||||||
|
associated constants. They must also conform to the requirements of tf losses
|
||||||
|
(see super class).
|
||||||
|
|
||||||
|
For more details on the strong convexity requirements, see:
|
||||||
|
Bolt-on Differential Privacy for Scalable Stochastic Gradient
|
||||||
|
Descent-based Analytics by Xi Wu et. al.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def radius(self):
|
||||||
|
"""Radius, R, of the hypothesis space W.
|
||||||
|
|
||||||
|
W is a convex set that forms the hypothesis space.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
R
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Radius not implemented for StrongConvex Loss"
|
||||||
|
"function: %s" % str(self.__class__.__name__))
|
||||||
|
|
||||||
|
def gamma(self):
|
||||||
|
"""Returns strongly convex parameter, gamma."""
|
||||||
|
raise NotImplementedError("Gamma not implemented for StrongConvex Loss"
|
||||||
|
"function: %s" % str(self.__class__.__name__))
|
||||||
|
|
||||||
|
def beta(self, class_weight):
|
||||||
|
"""Smoothness, beta.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_weight: the class weights as scalar or 1d tensor, where its
|
||||||
|
dimensionality is equal to the number of outputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Beta
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Beta not implemented for StrongConvex Loss"
|
||||||
|
"function: %s" % str(self.__class__.__name__))
|
||||||
|
|
||||||
|
def lipchitz_constant(self, class_weight):
|
||||||
|
"""Lipchitz constant, L.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_weight: class weights used
|
||||||
|
|
||||||
|
Returns: L
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("lipchitz constant not implemented for "
|
||||||
|
"StrongConvex Loss"
|
||||||
|
"function: %s" % str(self.__class__.__name__))
|
||||||
|
|
||||||
|
def kernel_regularizer(self):
|
||||||
|
"""Returns the kernel_regularizer to be used.
|
||||||
|
|
||||||
|
Any subclass should override this method if they want a kernel_regularizer
|
||||||
|
(if required for the loss function to be StronglyConvex.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def max_class_weight(self, class_weight, dtype):
|
||||||
|
"""The maximum weighting in class weights (max value) as a scalar tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_weight: class weights used
|
||||||
|
dtype: the data type for tensor conversions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
maximum class weighting as tensor scalar
|
||||||
|
"""
|
||||||
|
class_weight = _ops.convert_to_tensor_v2(class_weight, dtype)
|
||||||
|
return tf.math.reduce_max(class_weight)
|
||||||
|
|
||||||
|
|
||||||
|
class StrongConvexHuber(losses.Loss, StrongConvexMixin):
|
||||||
|
"""Strong Convex version of Huber loss using l2 weight regularization."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
reg_lambda,
|
||||||
|
c_arg,
|
||||||
|
radius_constant,
|
||||||
|
delta,
|
||||||
|
reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
|
||||||
|
dtype=tf.float32):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reg_lambda: Weight regularization constant
|
||||||
|
c_arg: Penalty parameter C of the loss term
|
||||||
|
radius_constant: constant defining the length of the radius
|
||||||
|
delta: delta value in huber loss. When to switch from quadratic to
|
||||||
|
absolute deviation.
|
||||||
|
reduction: reduction type to use. See super class
|
||||||
|
dtype: tf datatype to use for tensor conversions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loss values per sample.
|
||||||
|
"""
|
||||||
|
if c_arg <= 0:
|
||||||
|
raise ValueError("c: {0}, should be >= 0".format(c_arg))
|
||||||
|
if reg_lambda <= 0:
|
||||||
|
raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
|
||||||
|
if radius_constant <= 0:
|
||||||
|
raise ValueError("radius_constant: {0}, should be >= 0".format(
|
||||||
|
radius_constant
|
||||||
|
))
|
||||||
|
if delta <= 0:
|
||||||
|
raise ValueError("delta: {0}, should be >= 0".format(
|
||||||
|
delta
|
||||||
|
))
|
||||||
|
self.C = c_arg # pylint: disable=invalid-name
|
||||||
|
self.delta = delta
|
||||||
|
self.radius_constant = radius_constant
|
||||||
|
self.dtype = dtype
|
||||||
|
self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
|
||||||
|
super(StrongConvexHuber, self).__init__(
|
||||||
|
name="strongconvexhuber",
|
||||||
|
reduction=reduction,
|
||||||
|
)
|
||||||
|
|
||||||
|
def call(self, y_true, y_pred):
|
||||||
|
"""Computes loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y_true: Ground truth values. One hot encoded using -1 and 1.
|
||||||
|
y_pred: The predicted values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loss values per sample.
|
||||||
|
"""
|
||||||
|
h = self.delta
|
||||||
|
z = y_pred * y_true
|
||||||
|
one = tf.constant(1, dtype=self.dtype)
|
||||||
|
four = tf.constant(4, dtype=self.dtype)
|
||||||
|
|
||||||
|
if z > one + h: # pylint: disable=no-else-return
|
||||||
|
return _ops.convert_to_tensor_v2(0, dtype=self.dtype)
|
||||||
|
elif tf.math.abs(one - z) <= h:
|
||||||
|
return one / (four * h) * tf.math.pow(one + h - z, 2)
|
||||||
|
return one - z
|
||||||
|
|
||||||
|
def radius(self):
|
||||||
|
"""See super class."""
|
||||||
|
return self.radius_constant / self.reg_lambda
|
||||||
|
|
||||||
|
def gamma(self):
|
||||||
|
"""See super class."""
|
||||||
|
return self.reg_lambda
|
||||||
|
|
||||||
|
def beta(self, class_weight):
|
||||||
|
"""See super class."""
|
||||||
|
max_class_weight = self.max_class_weight(class_weight, self.dtype)
|
||||||
|
delta = _ops.convert_to_tensor_v2(self.delta,
|
||||||
|
dtype=self.dtype
|
||||||
|
)
|
||||||
|
return self.C * max_class_weight / (delta *
|
||||||
|
tf.constant(2, dtype=self.dtype)) + \
|
||||||
|
self.reg_lambda
|
||||||
|
|
||||||
|
def lipchitz_constant(self, class_weight):
|
||||||
|
"""See super class."""
|
||||||
|
# if class_weight is provided,
|
||||||
|
# it should be a vector of the same size of number of classes
|
||||||
|
max_class_weight = self.max_class_weight(class_weight, self.dtype)
|
||||||
|
lc = self.C * max_class_weight + \
|
||||||
|
self.reg_lambda * self.radius()
|
||||||
|
return lc
|
||||||
|
|
||||||
|
def kernel_regularizer(self):
|
||||||
|
"""Return l2 loss using 0.5*reg_lambda as the l2 term (as desired).
|
||||||
|
|
||||||
|
L2 regularization is required for this loss function to be strongly convex.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The L2 regularizer layer for this loss function, with regularizer constant
|
||||||
|
set to half the 0.5 * reg_lambda.
|
||||||
|
"""
|
||||||
|
return L1L2(l2=self.reg_lambda/2)
|
||||||
|
|
||||||
|
|
||||||
|
class StrongConvexBinaryCrossentropy(
|
||||||
|
losses.BinaryCrossentropy,
|
||||||
|
StrongConvexMixin
|
||||||
|
):
|
||||||
|
"""Strongly Convex BinaryCrossentropy loss using l2 weight regularization."""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
reg_lambda,
|
||||||
|
c_arg,
|
||||||
|
radius_constant,
|
||||||
|
from_logits=True,
|
||||||
|
label_smoothing=0,
|
||||||
|
reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
|
||||||
|
dtype=tf.float32):
|
||||||
|
"""StrongConvexBinaryCrossentropy class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reg_lambda: Weight regularization constant
|
||||||
|
c_arg: Penalty parameter C of the loss term
|
||||||
|
radius_constant: constant defining the length of the radius
|
||||||
|
from_logits: True if the input are unscaled logits. False if they are
|
||||||
|
already scaled.
|
||||||
|
label_smoothing: amount of smoothing to perform on labels
|
||||||
|
relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x). Note, the
|
||||||
|
impact of this parameter's effect on privacy is not known and thus the
|
||||||
|
default should be used.
|
||||||
|
reduction: reduction type to use. See super class
|
||||||
|
dtype: tf datatype to use for tensor conversions.
|
||||||
|
"""
|
||||||
|
if label_smoothing != 0:
|
||||||
|
logging.warning("The impact of label smoothing on privacy is unknown. "
|
||||||
|
"Use label smoothing at your own risk as it may not "
|
||||||
|
"guarantee privacy.")
|
||||||
|
|
||||||
|
if reg_lambda <= 0:
|
||||||
|
raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
|
||||||
|
if c_arg <= 0:
|
||||||
|
raise ValueError("c: {0}, should be >= 0".format(c_arg))
|
||||||
|
if radius_constant <= 0:
|
||||||
|
raise ValueError("radius_constant: {0}, should be >= 0".format(
|
||||||
|
radius_constant
|
||||||
|
))
|
||||||
|
self.dtype = dtype
|
||||||
|
self.C = c_arg # pylint: disable=invalid-name
|
||||||
|
self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
|
||||||
|
super(StrongConvexBinaryCrossentropy, self).__init__(
|
||||||
|
reduction=reduction,
|
||||||
|
name="strongconvexbinarycrossentropy",
|
||||||
|
from_logits=from_logits,
|
||||||
|
label_smoothing=label_smoothing,
|
||||||
|
)
|
||||||
|
self.radius_constant = radius_constant
|
||||||
|
|
||||||
|
def call(self, y_true, y_pred):
|
||||||
|
"""Computes loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y_true: Ground truth values.
|
||||||
|
y_pred: The predicted values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Loss values per sample.
|
||||||
|
"""
|
||||||
|
loss = super(StrongConvexBinaryCrossentropy, self).call(y_true, y_pred)
|
||||||
|
loss = loss * self.C
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def radius(self):
|
||||||
|
"""See super class."""
|
||||||
|
return self.radius_constant / self.reg_lambda
|
||||||
|
|
||||||
|
def gamma(self):
|
||||||
|
"""See super class."""
|
||||||
|
return self.reg_lambda
|
||||||
|
|
||||||
|
def beta(self, class_weight):
|
||||||
|
"""See super class."""
|
||||||
|
max_class_weight = self.max_class_weight(class_weight, self.dtype)
|
||||||
|
return self.C * max_class_weight + self.reg_lambda
|
||||||
|
|
||||||
|
def lipchitz_constant(self, class_weight):
|
||||||
|
"""See super class."""
|
||||||
|
max_class_weight = self.max_class_weight(class_weight, self.dtype)
|
||||||
|
return self.C * max_class_weight + self.reg_lambda * self.radius()
|
||||||
|
|
||||||
|
def kernel_regularizer(self):
|
||||||
|
"""Return l2 loss using 0.5*reg_lambda as the l2 term (as desired).
|
||||||
|
|
||||||
|
L2 regularization is required for this loss function to be strongly convex.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The L2 regularizer layer for this loss function, with regularizer constant
|
||||||
|
set to half the 0.5 * reg_lambda.
|
||||||
|
"""
|
||||||
|
return L1L2(l2=self.reg_lambda/2)
|
431
privacy/bolt_on/losses_test.py
Normal file
431
privacy/bolt_on/losses_test.py
Normal file
|
@ -0,0 +1,431 @@
|
||||||
|
# Copyright 2019, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Unit testing for losses."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from contextlib import contextmanager # pylint: disable=g-importing-member
|
||||||
|
from io import StringIO # pylint: disable=g-importing-member
|
||||||
|
import sys
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.keras import keras_parameterized
|
||||||
|
from tensorflow.python.keras.regularizers import L1L2
|
||||||
|
from privacy.bolt_on.losses import StrongConvexBinaryCrossentropy
|
||||||
|
from privacy.bolt_on.losses import StrongConvexHuber
|
||||||
|
from privacy.bolt_on.losses import StrongConvexMixin
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def captured_output():
|
||||||
|
"""Capture std_out and std_err within context."""
|
||||||
|
new_out, new_err = StringIO(), StringIO()
|
||||||
|
old_out, old_err = sys.stdout, sys.stderr
|
||||||
|
try:
|
||||||
|
sys.stdout, sys.stderr = new_out, new_err
|
||||||
|
yield sys.stdout, sys.stderr
|
||||||
|
finally:
|
||||||
|
sys.stdout, sys.stderr = old_out, old_err
|
||||||
|
|
||||||
|
|
||||||
|
class StrongConvexMixinTests(keras_parameterized.TestCase):
|
||||||
|
"""Tests for the StrongConvexMixin."""
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'beta not implemented',
|
||||||
|
'fn': 'beta',
|
||||||
|
'args': [1]},
|
||||||
|
{'testcase_name': 'gamma not implemented',
|
||||||
|
'fn': 'gamma',
|
||||||
|
'args': []},
|
||||||
|
{'testcase_name': 'lipchitz not implemented',
|
||||||
|
'fn': 'lipchitz_constant',
|
||||||
|
'args': [1]},
|
||||||
|
{'testcase_name': 'radius not implemented',
|
||||||
|
'fn': 'radius',
|
||||||
|
'args': []},
|
||||||
|
])
|
||||||
|
|
||||||
|
def test_not_implemented(self, fn, args):
|
||||||
|
"""Test that the given fn's are not implemented on the mixin.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: fn on Mixin to test
|
||||||
|
args: arguments to fn of Mixin
|
||||||
|
"""
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
loss = StrongConvexMixin()
|
||||||
|
getattr(loss, fn, None)(*args)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'radius not implemented',
|
||||||
|
'fn': 'kernel_regularizer',
|
||||||
|
'args': []},
|
||||||
|
])
|
||||||
|
def test_return_none(self, fn, args):
|
||||||
|
"""Test that fn of Mixin returns None.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: fn of Mixin to test
|
||||||
|
args: arguments to fn of Mixin
|
||||||
|
"""
|
||||||
|
loss = StrongConvexMixin()
|
||||||
|
ret = getattr(loss, fn, None)(*args)
|
||||||
|
self.assertEqual(ret, None)
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryCrossesntropyTests(keras_parameterized.TestCase):
|
||||||
|
"""tests for BinaryCrossesntropy StrongConvex loss."""
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'normal',
|
||||||
|
'reg_lambda': 1,
|
||||||
|
'C': 1,
|
||||||
|
'radius_constant': 1
|
||||||
|
}, # pylint: disable=invalid-name
|
||||||
|
])
|
||||||
|
def test_init_params(self, reg_lambda, C, radius_constant):
|
||||||
|
"""Test initialization for given arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reg_lambda: initialization value for reg_lambda arg
|
||||||
|
C: initialization value for C arg
|
||||||
|
radius_constant: initialization value for radius_constant arg
|
||||||
|
"""
|
||||||
|
# test valid domains for each variable
|
||||||
|
loss = StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
|
||||||
|
self.assertIsInstance(loss, StrongConvexBinaryCrossentropy)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'negative c',
|
||||||
|
'reg_lambda': 1,
|
||||||
|
'C': -1,
|
||||||
|
'radius_constant': 1
|
||||||
|
},
|
||||||
|
{'testcase_name': 'negative radius',
|
||||||
|
'reg_lambda': 1,
|
||||||
|
'C': 1,
|
||||||
|
'radius_constant': -1
|
||||||
|
},
|
||||||
|
{'testcase_name': 'negative lambda',
|
||||||
|
'reg_lambda': -1,
|
||||||
|
'C': 1,
|
||||||
|
'radius_constant': 1
|
||||||
|
}, # pylint: disable=invalid-name
|
||||||
|
])
|
||||||
|
def test_bad_init_params(self, reg_lambda, C, radius_constant):
|
||||||
|
"""Test invalid domain for given params. Should return ValueError.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reg_lambda: initialization value for reg_lambda arg
|
||||||
|
C: initialization value for C arg
|
||||||
|
radius_constant: initialization value for radius_constant arg
|
||||||
|
"""
|
||||||
|
# test valid domains for each variable
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
|
||||||
|
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
# [] for compatibility with tensorflow loss calculation
|
||||||
|
{'testcase_name': 'both positive',
|
||||||
|
'logits': [10000],
|
||||||
|
'y_true': [1],
|
||||||
|
'result': 0,
|
||||||
|
},
|
||||||
|
{'testcase_name': 'positive gradient negative logits',
|
||||||
|
'logits': [-10000],
|
||||||
|
'y_true': [1],
|
||||||
|
'result': 10000,
|
||||||
|
},
|
||||||
|
{'testcase_name': 'positivee gradient positive logits',
|
||||||
|
'logits': [10000],
|
||||||
|
'y_true': [0],
|
||||||
|
'result': 10000,
|
||||||
|
},
|
||||||
|
{'testcase_name': 'both negative',
|
||||||
|
'logits': [-10000],
|
||||||
|
'y_true': [0],
|
||||||
|
'result': 0
|
||||||
|
},
|
||||||
|
])
|
||||||
|
def test_calculation(self, logits, y_true, result):
|
||||||
|
"""Test the call method to ensure it returns the correct value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: unscaled output of model
|
||||||
|
y_true: label
|
||||||
|
result: correct loss calculation value
|
||||||
|
"""
|
||||||
|
logits = tf.Variable(logits, False, dtype=tf.float32)
|
||||||
|
y_true = tf.Variable(y_true, False, dtype=tf.float32)
|
||||||
|
loss = StrongConvexBinaryCrossentropy(0.00001, 1, 1)
|
||||||
|
loss = loss(y_true, logits)
|
||||||
|
self.assertEqual(loss.numpy(), result)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'beta',
|
||||||
|
'init_args': [1, 1, 1],
|
||||||
|
'fn': 'beta',
|
||||||
|
'args': [1],
|
||||||
|
'result': tf.constant(2, dtype=tf.float32)
|
||||||
|
},
|
||||||
|
{'testcase_name': 'gamma',
|
||||||
|
'fn': 'gamma',
|
||||||
|
'init_args': [1, 1, 1],
|
||||||
|
'args': [],
|
||||||
|
'result': tf.constant(1, dtype=tf.float32),
|
||||||
|
},
|
||||||
|
{'testcase_name': 'lipchitz constant',
|
||||||
|
'fn': 'lipchitz_constant',
|
||||||
|
'init_args': [1, 1, 1],
|
||||||
|
'args': [1],
|
||||||
|
'result': tf.constant(2, dtype=tf.float32),
|
||||||
|
},
|
||||||
|
{'testcase_name': 'kernel regularizer',
|
||||||
|
'fn': 'kernel_regularizer',
|
||||||
|
'init_args': [1, 1, 1],
|
||||||
|
'args': [],
|
||||||
|
'result': L1L2(l2=0.5),
|
||||||
|
},
|
||||||
|
])
|
||||||
|
def test_fns(self, init_args, fn, args, result):
|
||||||
|
"""Test that fn of BinaryCrossentropy loss returns the correct result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
init_args: init values for loss instance
|
||||||
|
fn: the fn to test
|
||||||
|
args: the arguments to above function
|
||||||
|
result: the correct result from the fn
|
||||||
|
"""
|
||||||
|
loss = StrongConvexBinaryCrossentropy(*init_args)
|
||||||
|
expected = getattr(loss, fn, lambda: 'fn not found')(*args)
|
||||||
|
if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor
|
||||||
|
expected = expected.numpy()
|
||||||
|
result = result.numpy()
|
||||||
|
if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer
|
||||||
|
expected = expected.l2
|
||||||
|
result = result.l2
|
||||||
|
self.assertEqual(expected, result)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'label_smoothing',
|
||||||
|
'init_args': [1, 1, 1, True, 0.1],
|
||||||
|
'fn': None,
|
||||||
|
'args': None,
|
||||||
|
'print_res': 'The impact of label smoothing on privacy is unknown.'
|
||||||
|
},
|
||||||
|
])
|
||||||
|
def test_prints(self, init_args, fn, args, print_res):
|
||||||
|
"""Test logger warning from StrongConvexBinaryCrossentropy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
init_args: arguments to init the object with.
|
||||||
|
fn: function to test
|
||||||
|
args: arguments to above function
|
||||||
|
print_res: print result that should have been printed.
|
||||||
|
"""
|
||||||
|
with captured_output() as (out, err): # pylint: disable=unused-variable
|
||||||
|
loss = StrongConvexBinaryCrossentropy(*init_args)
|
||||||
|
if fn is not None:
|
||||||
|
getattr(loss, fn, lambda *arguments: print('error'))(*args)
|
||||||
|
self.assertRegexMatch(err.getvalue().strip(), [print_res])
|
||||||
|
|
||||||
|
|
||||||
|
class HuberTests(keras_parameterized.TestCase):
|
||||||
|
"""tests for BinaryCrossesntropy StrongConvex loss."""
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'normal',
|
||||||
|
'reg_lambda': 1,
|
||||||
|
'c': 1,
|
||||||
|
'radius_constant': 1,
|
||||||
|
'delta': 1,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
def test_init_params(self, reg_lambda, c, radius_constant, delta):
|
||||||
|
"""Test initialization for given arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reg_lambda: initialization value for reg_lambda arg
|
||||||
|
c: initialization value for C arg
|
||||||
|
radius_constant: initialization value for radius_constant arg
|
||||||
|
delta: the delta parameter for the huber loss
|
||||||
|
"""
|
||||||
|
# test valid domains for each variable
|
||||||
|
loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta)
|
||||||
|
self.assertIsInstance(loss, StrongConvexHuber)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'negative c',
|
||||||
|
'reg_lambda': 1,
|
||||||
|
'c': -1,
|
||||||
|
'radius_constant': 1,
|
||||||
|
'delta': 1
|
||||||
|
},
|
||||||
|
{'testcase_name': 'negative radius',
|
||||||
|
'reg_lambda': 1,
|
||||||
|
'c': 1,
|
||||||
|
'radius_constant': -1,
|
||||||
|
'delta': 1
|
||||||
|
},
|
||||||
|
{'testcase_name': 'negative lambda',
|
||||||
|
'reg_lambda': -1,
|
||||||
|
'c': 1,
|
||||||
|
'radius_constant': 1,
|
||||||
|
'delta': 1
|
||||||
|
},
|
||||||
|
{'testcase_name': 'negative delta',
|
||||||
|
'reg_lambda': 1,
|
||||||
|
'c': 1,
|
||||||
|
'radius_constant': 1,
|
||||||
|
'delta': -1
|
||||||
|
},
|
||||||
|
])
|
||||||
|
def test_bad_init_params(self, reg_lambda, c, radius_constant, delta):
|
||||||
|
"""Test invalid domain for given params. Should return ValueError.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reg_lambda: initialization value for reg_lambda arg
|
||||||
|
c: initialization value for C arg
|
||||||
|
radius_constant: initialization value for radius_constant arg
|
||||||
|
delta: the delta parameter for the huber loss
|
||||||
|
"""
|
||||||
|
# test valid domains for each variable
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
StrongConvexHuber(reg_lambda, c, radius_constant, delta)
|
||||||
|
|
||||||
|
# test the bounds and test varied delta's
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'delta=1,y_true=1 z>1+h decision boundary',
|
||||||
|
'logits': 2.1,
|
||||||
|
'y_true': 1,
|
||||||
|
'delta': 1,
|
||||||
|
'result': 0,
|
||||||
|
},
|
||||||
|
{'testcase_name': 'delta=1,y_true=1 z<1+h decision boundary',
|
||||||
|
'logits': 1.9,
|
||||||
|
'y_true': 1,
|
||||||
|
'delta': 1,
|
||||||
|
'result': 0.01*0.25,
|
||||||
|
},
|
||||||
|
{'testcase_name': 'delta=1,y_true=1 1-z< h decision boundary',
|
||||||
|
'logits': 0.1,
|
||||||
|
'y_true': 1,
|
||||||
|
'delta': 1,
|
||||||
|
'result': 1.9**2 * 0.25,
|
||||||
|
},
|
||||||
|
{'testcase_name': 'delta=1,y_true=1 z < 1-h decision boundary',
|
||||||
|
'logits': -0.1,
|
||||||
|
'y_true': 1,
|
||||||
|
'delta': 1,
|
||||||
|
'result': 1.1,
|
||||||
|
},
|
||||||
|
{'testcase_name': 'delta=2,y_true=1 z>1+h decision boundary',
|
||||||
|
'logits': 3.1,
|
||||||
|
'y_true': 1,
|
||||||
|
'delta': 2,
|
||||||
|
'result': 0,
|
||||||
|
},
|
||||||
|
{'testcase_name': 'delta=2,y_true=1 z<1+h decision boundary',
|
||||||
|
'logits': 2.9,
|
||||||
|
'y_true': 1,
|
||||||
|
'delta': 2,
|
||||||
|
'result': 0.01*0.125,
|
||||||
|
},
|
||||||
|
{'testcase_name': 'delta=2,y_true=1 1-z < h decision boundary',
|
||||||
|
'logits': 1.1,
|
||||||
|
'y_true': 1,
|
||||||
|
'delta': 2,
|
||||||
|
'result': 1.9**2 * 0.125,
|
||||||
|
},
|
||||||
|
{'testcase_name': 'delta=2,y_true=1 z < 1-h decision boundary',
|
||||||
|
'logits': -1.1,
|
||||||
|
'y_true': 1,
|
||||||
|
'delta': 2,
|
||||||
|
'result': 2.1,
|
||||||
|
},
|
||||||
|
{'testcase_name': 'delta=1,y_true=-1 z>1+h decision boundary',
|
||||||
|
'logits': -2.1,
|
||||||
|
'y_true': -1,
|
||||||
|
'delta': 1,
|
||||||
|
'result': 0,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
def test_calculation(self, logits, y_true, delta, result):
|
||||||
|
"""Test the call method to ensure it returns the correct value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: unscaled output of model
|
||||||
|
y_true: label
|
||||||
|
delta: delta value for StrongConvexHuber loss.
|
||||||
|
result: correct loss calculation value
|
||||||
|
"""
|
||||||
|
logits = tf.Variable(logits, False, dtype=tf.float32)
|
||||||
|
y_true = tf.Variable(y_true, False, dtype=tf.float32)
|
||||||
|
loss = StrongConvexHuber(0.00001, 1, 1, delta)
|
||||||
|
loss = loss(y_true, logits)
|
||||||
|
self.assertAllClose(loss.numpy(), result)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'beta',
|
||||||
|
'init_args': [1, 1, 1, 1],
|
||||||
|
'fn': 'beta',
|
||||||
|
'args': [1],
|
||||||
|
'result': tf.Variable(1.5, dtype=tf.float32)
|
||||||
|
},
|
||||||
|
{'testcase_name': 'gamma',
|
||||||
|
'fn': 'gamma',
|
||||||
|
'init_args': [1, 1, 1, 1],
|
||||||
|
'args': [],
|
||||||
|
'result': tf.Variable(1, dtype=tf.float32),
|
||||||
|
},
|
||||||
|
{'testcase_name': 'lipchitz constant',
|
||||||
|
'fn': 'lipchitz_constant',
|
||||||
|
'init_args': [1, 1, 1, 1],
|
||||||
|
'args': [1],
|
||||||
|
'result': tf.Variable(2, dtype=tf.float32),
|
||||||
|
},
|
||||||
|
{'testcase_name': 'kernel regularizer',
|
||||||
|
'fn': 'kernel_regularizer',
|
||||||
|
'init_args': [1, 1, 1, 1],
|
||||||
|
'args': [],
|
||||||
|
'result': L1L2(l2=0.5),
|
||||||
|
},
|
||||||
|
])
|
||||||
|
def test_fns(self, init_args, fn, args, result):
|
||||||
|
"""Test that fn of BinaryCrossentropy loss returns the correct result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
init_args: init values for loss instance
|
||||||
|
fn: the fn to test
|
||||||
|
args: the arguments to above function
|
||||||
|
result: the correct result from the fn
|
||||||
|
"""
|
||||||
|
loss = StrongConvexHuber(*init_args)
|
||||||
|
expected = getattr(loss, fn, lambda: 'fn not found')(*args)
|
||||||
|
if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor
|
||||||
|
expected = expected.numpy()
|
||||||
|
result = result.numpy()
|
||||||
|
if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer
|
||||||
|
expected = expected.l2
|
||||||
|
result = result.l2
|
||||||
|
self.assertEqual(expected, result)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
297
privacy/bolt_on/models.py
Normal file
297
privacy/bolt_on/models.py
Normal file
|
@ -0,0 +1,297 @@
|
||||||
|
# Copyright 2019, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""BoltOn model for Bolt-on method of differentially private ML."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.framework import ops as _ops
|
||||||
|
from tensorflow.python.keras import optimizers
|
||||||
|
from tensorflow.python.keras.models import Model
|
||||||
|
from privacy.bolt_on.losses import StrongConvexMixin
|
||||||
|
from privacy.bolt_on.optimizers import BoltOn
|
||||||
|
|
||||||
|
|
||||||
|
class BoltOnModel(Model): # pylint: disable=abstract-method
|
||||||
|
"""BoltOn episilon-delta differential privacy model.
|
||||||
|
|
||||||
|
The privacy guarantees are dependent on the noise that is sampled. Please
|
||||||
|
see the paper linked below for more details.
|
||||||
|
|
||||||
|
Uses 4 key steps to achieve privacy guarantees:
|
||||||
|
1. Adds noise to weights after training (output perturbation).
|
||||||
|
2. Projects weights to R after each batch
|
||||||
|
3. Limits learning rate
|
||||||
|
4. Use a strongly convex loss function (see compile)
|
||||||
|
|
||||||
|
For more details on the strong convexity requirements, see:
|
||||||
|
Bolt-on Differential Privacy for Scalable Stochastic Gradient
|
||||||
|
Descent-based Analytics by Xi Wu et al.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
n_outputs,
|
||||||
|
seed=1,
|
||||||
|
dtype=tf.float32):
|
||||||
|
"""Private constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_outputs: number of output classes to predict.
|
||||||
|
seed: random seed to use
|
||||||
|
dtype: data type to use for tensors
|
||||||
|
"""
|
||||||
|
super(BoltOnModel, self).__init__(name='bolton', dynamic=False)
|
||||||
|
if n_outputs <= 0:
|
||||||
|
raise ValueError('n_outputs = {0} is not valid. Must be > 0.'.format(
|
||||||
|
n_outputs
|
||||||
|
))
|
||||||
|
self.n_outputs = n_outputs
|
||||||
|
self.seed = seed
|
||||||
|
self._layers_instantiated = False
|
||||||
|
self._dtype = dtype
|
||||||
|
|
||||||
|
def call(self, inputs): # pylint: disable=arguments-differ
|
||||||
|
"""Forward pass of network.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: inputs to neural network
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output logits for the given inputs.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return self.output_layer(inputs)
|
||||||
|
|
||||||
|
def compile(self,
|
||||||
|
optimizer,
|
||||||
|
loss,
|
||||||
|
kernel_initializer=tf.initializers.GlorotUniform,
|
||||||
|
**kwargs): # pylint: disable=arguments-differ
|
||||||
|
"""See super class. Default optimizer used in BoltOn method is SGD.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer: The optimizer to use. This will be automatically wrapped
|
||||||
|
with the BoltOn Optimizer.
|
||||||
|
loss: The loss function to use. Must be a StrongConvex loss (extend the
|
||||||
|
StrongConvexMixin).
|
||||||
|
kernel_initializer: The kernel initializer to use for the single layer.
|
||||||
|
**kwargs: kwargs to keras Model.compile. See super.
|
||||||
|
"""
|
||||||
|
if not isinstance(loss, StrongConvexMixin):
|
||||||
|
raise ValueError('loss function must be a Strongly Convex and therefore '
|
||||||
|
'extend the StrongConvexMixin.')
|
||||||
|
if not self._layers_instantiated: # compile may be called multiple times
|
||||||
|
# for instance, if the input/outputs are not defined until fit.
|
||||||
|
self.output_layer = tf.keras.layers.Dense(
|
||||||
|
self.n_outputs,
|
||||||
|
kernel_regularizer=loss.kernel_regularizer(),
|
||||||
|
kernel_initializer=kernel_initializer(),
|
||||||
|
)
|
||||||
|
self._layers_instantiated = True
|
||||||
|
if not isinstance(optimizer, BoltOn):
|
||||||
|
optimizer = optimizers.get(optimizer)
|
||||||
|
optimizer = BoltOn(optimizer, loss)
|
||||||
|
|
||||||
|
super(BoltOnModel, self).compile(optimizer, loss=loss, **kwargs)
|
||||||
|
|
||||||
|
def fit(self,
|
||||||
|
x=None,
|
||||||
|
y=None,
|
||||||
|
batch_size=None,
|
||||||
|
class_weight=None,
|
||||||
|
n_samples=None,
|
||||||
|
epsilon=2,
|
||||||
|
noise_distribution='laplace',
|
||||||
|
steps_per_epoch=None,
|
||||||
|
**kwargs): # pylint: disable=arguments-differ
|
||||||
|
"""Reroutes to super fit with BoltOn delta-epsilon privacy requirements.
|
||||||
|
|
||||||
|
Note, inputs must be normalized s.t. ||x|| < 1.
|
||||||
|
Requirements are as follows:
|
||||||
|
1. Adds noise to weights after training (output perturbation).
|
||||||
|
2. Projects weights to R after each batch
|
||||||
|
3. Limits learning rate
|
||||||
|
4. Use a strongly convex loss function (see compile)
|
||||||
|
See super implementation for more details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Inputs to fit on, see super.
|
||||||
|
y: Labels to fit on, see super.
|
||||||
|
batch_size: The batch size to use for training, see super.
|
||||||
|
class_weight: the class weights to be used. Can be a scalar or 1D tensor
|
||||||
|
whose dim == n_classes.
|
||||||
|
n_samples: the number of individual samples in x.
|
||||||
|
epsilon: privacy parameter, which trades off between utility an privacy.
|
||||||
|
See the bolt-on paper for more description.
|
||||||
|
noise_distribution: the distribution to pull noise from.
|
||||||
|
steps_per_epoch:
|
||||||
|
**kwargs: kwargs to keras Model.fit. See super.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output from super fit method.
|
||||||
|
"""
|
||||||
|
if class_weight is None:
|
||||||
|
class_weight_ = self.calculate_class_weights(class_weight)
|
||||||
|
else:
|
||||||
|
class_weight_ = class_weight
|
||||||
|
if n_samples is not None:
|
||||||
|
data_size = n_samples
|
||||||
|
elif hasattr(x, 'shape'):
|
||||||
|
data_size = x.shape[0]
|
||||||
|
elif hasattr(x, '__len__'):
|
||||||
|
data_size = len(x)
|
||||||
|
else:
|
||||||
|
data_size = None
|
||||||
|
batch_size_ = self._validate_or_infer_batch_size(batch_size,
|
||||||
|
steps_per_epoch,
|
||||||
|
x)
|
||||||
|
# inferring batch_size to be passed to optimizer. batch_size must remain its
|
||||||
|
# initial value when passed to super().fit()
|
||||||
|
if batch_size_ is None:
|
||||||
|
raise ValueError('batch_size: {0} is an '
|
||||||
|
'invalid value'.format(batch_size_))
|
||||||
|
if data_size is None:
|
||||||
|
raise ValueError('Could not infer the number of samples. Please pass '
|
||||||
|
'this in using n_samples.')
|
||||||
|
with self.optimizer(noise_distribution,
|
||||||
|
epsilon,
|
||||||
|
self.layers,
|
||||||
|
class_weight_,
|
||||||
|
data_size,
|
||||||
|
batch_size_) as _:
|
||||||
|
out = super(BoltOnModel, self).fit(x=x,
|
||||||
|
y=y,
|
||||||
|
batch_size=batch_size,
|
||||||
|
class_weight=class_weight,
|
||||||
|
steps_per_epoch=steps_per_epoch,
|
||||||
|
**kwargs)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def fit_generator(self,
|
||||||
|
generator,
|
||||||
|
class_weight=None,
|
||||||
|
noise_distribution='laplace',
|
||||||
|
epsilon=2,
|
||||||
|
n_samples=None,
|
||||||
|
steps_per_epoch=None,
|
||||||
|
**kwargs): # pylint: disable=arguments-differ
|
||||||
|
"""Fit with a generator.
|
||||||
|
|
||||||
|
This method is the same as fit except for when the passed dataset
|
||||||
|
is a generator. See super method and fit for more details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generator: Inputs generator following Tensorflow guidelines, see super.
|
||||||
|
class_weight: the class weights to be used. Can be a scalar or 1D tensor
|
||||||
|
whose dim == n_classes.
|
||||||
|
noise_distribution: the distribution to get noise from.
|
||||||
|
epsilon: privacy parameter, which trades off utility and privacy. See
|
||||||
|
BoltOn paper for more description.
|
||||||
|
n_samples: number of individual samples in x
|
||||||
|
steps_per_epoch: Number of steps per training epoch, see super.
|
||||||
|
**kwargs: **kwargs
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Output from super fit_generator method.
|
||||||
|
"""
|
||||||
|
if class_weight is None:
|
||||||
|
class_weight = self.calculate_class_weights(class_weight)
|
||||||
|
if n_samples is not None:
|
||||||
|
data_size = n_samples
|
||||||
|
elif hasattr(generator, 'shape'):
|
||||||
|
data_size = generator.shape[0]
|
||||||
|
elif hasattr(generator, '__len__'):
|
||||||
|
data_size = len(generator)
|
||||||
|
else:
|
||||||
|
data_size = None
|
||||||
|
batch_size = self._validate_or_infer_batch_size(None,
|
||||||
|
steps_per_epoch,
|
||||||
|
generator)
|
||||||
|
with self.optimizer(noise_distribution,
|
||||||
|
epsilon,
|
||||||
|
self.layers,
|
||||||
|
class_weight,
|
||||||
|
data_size,
|
||||||
|
batch_size) as _:
|
||||||
|
out = super(BoltOnModel, self).fit_generator(
|
||||||
|
generator,
|
||||||
|
class_weight=class_weight,
|
||||||
|
steps_per_epoch=steps_per_epoch,
|
||||||
|
**kwargs)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def calculate_class_weights(self,
|
||||||
|
class_weights=None,
|
||||||
|
class_counts=None,
|
||||||
|
num_classes=None):
|
||||||
|
"""Calculates class weighting to be used in training.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_weights: str specifying type, array giving weights, or None.
|
||||||
|
class_counts: If class_weights is not None, then an array of
|
||||||
|
the number of samples for each class
|
||||||
|
num_classes: If class_weights is not None, then the number of
|
||||||
|
classes.
|
||||||
|
Returns:
|
||||||
|
class_weights as 1D tensor, to be passed to model's fit method.
|
||||||
|
"""
|
||||||
|
# Value checking
|
||||||
|
class_keys = ['balanced']
|
||||||
|
is_string = False
|
||||||
|
if isinstance(class_weights, str):
|
||||||
|
is_string = True
|
||||||
|
if class_weights not in class_keys:
|
||||||
|
raise ValueError('Detected string class_weights with '
|
||||||
|
'value: {0}, which is not one of {1}.'
|
||||||
|
'Please select a valid class_weight type'
|
||||||
|
'or pass an array'.format(class_weights,
|
||||||
|
class_keys))
|
||||||
|
if class_counts is None:
|
||||||
|
raise ValueError('Class counts must be provided if using '
|
||||||
|
'class_weights=%s' % class_weights)
|
||||||
|
class_counts_shape = tf.Variable(class_counts,
|
||||||
|
trainable=False,
|
||||||
|
dtype=self._dtype).shape
|
||||||
|
if len(class_counts_shape) != 1:
|
||||||
|
raise ValueError('class counts must be a 1D array.'
|
||||||
|
'Detected: {0}'.format(class_counts_shape))
|
||||||
|
if num_classes is None:
|
||||||
|
raise ValueError('num_classes must be provided if using '
|
||||||
|
'class_weights=%s' % class_weights)
|
||||||
|
elif class_weights is not None:
|
||||||
|
if num_classes is None:
|
||||||
|
raise ValueError('You must pass a value for num_classes if '
|
||||||
|
'creating an array of class_weights')
|
||||||
|
# performing class weight calculation
|
||||||
|
if class_weights is None:
|
||||||
|
class_weights = 1
|
||||||
|
elif is_string and class_weights == 'balanced':
|
||||||
|
num_samples = sum(class_counts)
|
||||||
|
weighted_counts = tf.dtypes.cast(tf.math.multiply(num_classes,
|
||||||
|
class_counts),
|
||||||
|
self._dtype)
|
||||||
|
class_weights = tf.Variable(num_samples, dtype=self._dtype) / \
|
||||||
|
tf.Variable(weighted_counts, dtype=self._dtype)
|
||||||
|
else:
|
||||||
|
class_weights = _ops.convert_to_tensor_v2(class_weights)
|
||||||
|
if len(class_weights.shape) != 1:
|
||||||
|
raise ValueError('Detected class_weights shape: {0} instead of '
|
||||||
|
'1D array'.format(class_weights.shape))
|
||||||
|
if class_weights.shape[0] != num_classes:
|
||||||
|
raise ValueError(
|
||||||
|
'Detected array length: {0} instead of: {1}'.format(
|
||||||
|
class_weights.shape[0],
|
||||||
|
num_classes))
|
||||||
|
return class_weights
|
534
privacy/bolt_on/models_test.py
Normal file
534
privacy/bolt_on/models_test.py
Normal file
|
@ -0,0 +1,534 @@
|
||||||
|
# Copyright 2019, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Unit testing for models."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.framework import ops as _ops
|
||||||
|
from tensorflow.python.keras import keras_parameterized
|
||||||
|
from tensorflow.python.keras import losses
|
||||||
|
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
|
||||||
|
from tensorflow.python.keras.regularizers import L1L2
|
||||||
|
from privacy.bolt_on import models
|
||||||
|
from privacy.bolt_on.losses import StrongConvexMixin
|
||||||
|
from privacy.bolt_on.optimizers import BoltOn
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoss(losses.Loss, StrongConvexMixin):
|
||||||
|
"""Test loss function for testing BoltOn model."""
|
||||||
|
|
||||||
|
def __init__(self, reg_lambda, c_arg, radius_constant, name='test'):
|
||||||
|
super(TestLoss, self).__init__(name=name)
|
||||||
|
self.reg_lambda = reg_lambda
|
||||||
|
self.C = c_arg # pylint: disable=invalid-name
|
||||||
|
self.radius_constant = radius_constant
|
||||||
|
|
||||||
|
def radius(self):
|
||||||
|
"""Radius, R, of the hypothesis space W.
|
||||||
|
|
||||||
|
W is a convex set that forms the hypothesis space.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
radius
|
||||||
|
"""
|
||||||
|
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
|
||||||
|
|
||||||
|
def gamma(self):
|
||||||
|
"""Returns strongly convex parameter, gamma."""
|
||||||
|
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
|
||||||
|
|
||||||
|
def beta(self, class_weight): # pylint: disable=unused-argument
|
||||||
|
"""Smoothness, beta.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_weight: the class weights as scalar or 1d tensor, where its
|
||||||
|
dimensionality is equal to the number of outputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Beta
|
||||||
|
"""
|
||||||
|
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
|
||||||
|
|
||||||
|
def lipchitz_constant(self, class_weight): # pylint: disable=unused-argument
|
||||||
|
"""Lipchitz constant, L.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_weight: class weights used
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
L
|
||||||
|
"""
|
||||||
|
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
|
||||||
|
|
||||||
|
def call(self, y_true, y_pred):
|
||||||
|
"""Loss function that is minimized at the mean of the input points."""
|
||||||
|
return 0.5 * tf.reduce_sum(
|
||||||
|
tf.math.squared_difference(y_true, y_pred),
|
||||||
|
axis=1
|
||||||
|
)
|
||||||
|
|
||||||
|
def max_class_weight(self, class_weight):
|
||||||
|
"""the maximum weighting in class weights (max value) as a scalar tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_weight: class weights used
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
maximum class weighting as tensor scalar
|
||||||
|
"""
|
||||||
|
if class_weight is None:
|
||||||
|
return 1
|
||||||
|
raise ValueError('')
|
||||||
|
|
||||||
|
def kernel_regularizer(self):
|
||||||
|
"""Returns the kernel_regularizer to be used.
|
||||||
|
|
||||||
|
Any subclass should override this method if they want a kernel_regularizer
|
||||||
|
(if required for the loss function to be StronglyConvex.
|
||||||
|
"""
|
||||||
|
return L1L2(l2=self.reg_lambda)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOptimizer(OptimizerV2):
|
||||||
|
"""Test optimizer used for testing BoltOn model."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(TestOptimizer, self).__init__('test')
|
||||||
|
|
||||||
|
def compute_gradients(self):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _create_slots(self, var):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _resource_apply_dense(self, grad, handle):
|
||||||
|
return grad
|
||||||
|
|
||||||
|
def _resource_apply_sparse(self, grad, handle, indices):
|
||||||
|
return grad
|
||||||
|
|
||||||
|
|
||||||
|
class InitTests(keras_parameterized.TestCase):
|
||||||
|
"""Tests for keras model initialization."""
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'normal',
|
||||||
|
'n_outputs': 1,
|
||||||
|
},
|
||||||
|
{'testcase_name': 'many outputs',
|
||||||
|
'n_outputs': 100,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
def test_init_params(self, n_outputs):
|
||||||
|
"""Test initialization of BoltOnModel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_outputs: number of output neurons
|
||||||
|
"""
|
||||||
|
# test valid domains for each variable
|
||||||
|
clf = models.BoltOnModel(n_outputs)
|
||||||
|
self.assertIsInstance(clf, models.BoltOnModel)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'invalid n_outputs',
|
||||||
|
'n_outputs': -1,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
def test_bad_init_params(self, n_outputs):
|
||||||
|
"""test bad initializations of BoltOnModel that should raise errors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_outputs: number of output neurons
|
||||||
|
"""
|
||||||
|
# test invalid domains for each variable, especially noise
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
models.BoltOnModel(n_outputs)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'string compile',
|
||||||
|
'n_outputs': 1,
|
||||||
|
'loss': TestLoss(1, 1, 1),
|
||||||
|
'optimizer': 'adam',
|
||||||
|
},
|
||||||
|
{'testcase_name': 'test compile',
|
||||||
|
'n_outputs': 100,
|
||||||
|
'loss': TestLoss(1, 1, 1),
|
||||||
|
'optimizer': TestOptimizer(),
|
||||||
|
},
|
||||||
|
])
|
||||||
|
def test_compile(self, n_outputs, loss, optimizer):
|
||||||
|
"""Test compilation of BoltOnModel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_outputs: number of output neurons
|
||||||
|
loss: instantiated TestLoss instance
|
||||||
|
optimizer: instantiated TestOptimizer instance
|
||||||
|
"""
|
||||||
|
# test compilation of valid tf.optimizer and tf.loss
|
||||||
|
with self.cached_session():
|
||||||
|
clf = models.BoltOnModel(n_outputs)
|
||||||
|
clf.compile(optimizer, loss)
|
||||||
|
self.assertEqual(clf.loss, loss)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'Not strong loss',
|
||||||
|
'n_outputs': 1,
|
||||||
|
'loss': losses.BinaryCrossentropy(),
|
||||||
|
'optimizer': 'adam',
|
||||||
|
},
|
||||||
|
{'testcase_name': 'Not valid optimizer',
|
||||||
|
'n_outputs': 1,
|
||||||
|
'loss': TestLoss(1, 1, 1),
|
||||||
|
'optimizer': 'ada',
|
||||||
|
}
|
||||||
|
])
|
||||||
|
def test_bad_compile(self, n_outputs, loss, optimizer):
|
||||||
|
"""test bad compilations of BoltOnModel that should raise errors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_outputs: number of output neurons
|
||||||
|
loss: instantiated TestLoss instance
|
||||||
|
optimizer: instantiated TestOptimizer instance
|
||||||
|
"""
|
||||||
|
# test compilaton of invalid tf.optimizer and non instantiated loss.
|
||||||
|
with self.cached_session():
|
||||||
|
with self.assertRaises((ValueError, AttributeError)):
|
||||||
|
clf = models.BoltOnModel(n_outputs)
|
||||||
|
clf.compile(optimizer, loss)
|
||||||
|
|
||||||
|
|
||||||
|
def _cat_dataset(n_samples, input_dim, n_classes, generator=False):
|
||||||
|
"""Creates a categorically encoded dataset.
|
||||||
|
|
||||||
|
Creates a categorically encoded dataset (y is categorical).
|
||||||
|
returns the specified dataset either as a static array or as a generator.
|
||||||
|
Will have evenly split samples across each output class.
|
||||||
|
Each output class will be a different point in the input space.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_samples: number of rows
|
||||||
|
input_dim: input dimensionality
|
||||||
|
n_classes: output dimensionality
|
||||||
|
generator: False for array, True for generator
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
X as (n_samples, input_dim), Y as (n_samples, n_outputs)
|
||||||
|
"""
|
||||||
|
x_stack = []
|
||||||
|
y_stack = []
|
||||||
|
for i_class in range(n_classes):
|
||||||
|
x_stack.append(
|
||||||
|
tf.constant(1*i_class, tf.float32, (n_samples, input_dim))
|
||||||
|
)
|
||||||
|
y_stack.append(
|
||||||
|
tf.constant(i_class, tf.float32, (n_samples, n_classes))
|
||||||
|
)
|
||||||
|
x_set, y_set = tf.stack(x_stack), tf.stack(y_stack)
|
||||||
|
if generator:
|
||||||
|
dataset = tf.data.Dataset.from_tensor_slices(
|
||||||
|
(x_set, y_set)
|
||||||
|
)
|
||||||
|
return dataset
|
||||||
|
return x_set, y_set
|
||||||
|
|
||||||
|
|
||||||
|
def _do_fit(n_samples,
|
||||||
|
input_dim,
|
||||||
|
n_outputs,
|
||||||
|
epsilon,
|
||||||
|
generator,
|
||||||
|
batch_size,
|
||||||
|
reset_n_samples,
|
||||||
|
optimizer,
|
||||||
|
loss,
|
||||||
|
distribution='laplace'):
|
||||||
|
"""Instantiate necessary components for fitting and perform a model fit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_samples: number of samples in dataset
|
||||||
|
input_dim: the sample dimensionality
|
||||||
|
n_outputs: number of output neurons
|
||||||
|
epsilon: privacy parameter
|
||||||
|
generator: True to create a generator, False to use an iterator
|
||||||
|
batch_size: batch_size to use
|
||||||
|
reset_n_samples: True to set _samples to None prior to fitting.
|
||||||
|
False does nothing
|
||||||
|
optimizer: instance of TestOptimizer
|
||||||
|
loss: instance of TestLoss
|
||||||
|
distribution: distribution to get noise from.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BoltOnModel instsance
|
||||||
|
"""
|
||||||
|
clf = models.BoltOnModel(n_outputs)
|
||||||
|
clf.compile(optimizer, loss)
|
||||||
|
if generator:
|
||||||
|
x = _cat_dataset(
|
||||||
|
n_samples,
|
||||||
|
input_dim,
|
||||||
|
n_outputs,
|
||||||
|
generator=generator
|
||||||
|
)
|
||||||
|
y = None
|
||||||
|
# x = x.batch(batch_size)
|
||||||
|
x = x.shuffle(n_samples//2)
|
||||||
|
batch_size = None
|
||||||
|
else:
|
||||||
|
x, y = _cat_dataset(n_samples, input_dim, n_outputs, generator=generator)
|
||||||
|
if reset_n_samples:
|
||||||
|
n_samples = None
|
||||||
|
|
||||||
|
clf.fit(x,
|
||||||
|
y,
|
||||||
|
batch_size=batch_size,
|
||||||
|
n_samples=n_samples,
|
||||||
|
noise_distribution=distribution,
|
||||||
|
epsilon=epsilon)
|
||||||
|
return clf
|
||||||
|
|
||||||
|
|
||||||
|
class FitTests(keras_parameterized.TestCase):
|
||||||
|
"""Test cases for keras model fitting."""
|
||||||
|
|
||||||
|
# @test_util.run_all_in_graph_and_eager_modes
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'iterator fit',
|
||||||
|
'generator': False,
|
||||||
|
'reset_n_samples': True,
|
||||||
|
},
|
||||||
|
{'testcase_name': 'iterator fit no samples',
|
||||||
|
'generator': False,
|
||||||
|
'reset_n_samples': True,
|
||||||
|
},
|
||||||
|
{'testcase_name': 'generator fit',
|
||||||
|
'generator': True,
|
||||||
|
'reset_n_samples': False,
|
||||||
|
},
|
||||||
|
{'testcase_name': 'with callbacks',
|
||||||
|
'generator': True,
|
||||||
|
'reset_n_samples': False,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
def test_fit(self, generator, reset_n_samples):
|
||||||
|
"""Tests fitting of BoltOnModel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generator: True for generator test, False for iterator test.
|
||||||
|
reset_n_samples: True to reset the n_samples to None, False does nothing
|
||||||
|
"""
|
||||||
|
loss = TestLoss(1, 1, 1)
|
||||||
|
optimizer = BoltOn(TestOptimizer(), loss)
|
||||||
|
n_classes = 2
|
||||||
|
input_dim = 5
|
||||||
|
epsilon = 1
|
||||||
|
batch_size = 1
|
||||||
|
n_samples = 10
|
||||||
|
clf = _do_fit(
|
||||||
|
n_samples,
|
||||||
|
input_dim,
|
||||||
|
n_classes,
|
||||||
|
epsilon,
|
||||||
|
generator,
|
||||||
|
batch_size,
|
||||||
|
reset_n_samples,
|
||||||
|
optimizer,
|
||||||
|
loss,
|
||||||
|
)
|
||||||
|
self.assertEqual(hasattr(clf, 'layers'), True)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'generator fit',
|
||||||
|
'generator': True,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
def test_fit_gen(self, generator):
|
||||||
|
"""Tests the fit_generator method of BoltOnModel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generator: True to test with a generator dataset
|
||||||
|
"""
|
||||||
|
loss = TestLoss(1, 1, 1)
|
||||||
|
optimizer = TestOptimizer()
|
||||||
|
n_classes = 2
|
||||||
|
input_dim = 5
|
||||||
|
batch_size = 1
|
||||||
|
n_samples = 10
|
||||||
|
clf = models.BoltOnModel(n_classes)
|
||||||
|
clf.compile(optimizer, loss)
|
||||||
|
x = _cat_dataset(
|
||||||
|
n_samples,
|
||||||
|
input_dim,
|
||||||
|
n_classes,
|
||||||
|
generator=generator
|
||||||
|
)
|
||||||
|
x = x.batch(batch_size)
|
||||||
|
x = x.shuffle(n_samples // 2)
|
||||||
|
clf.fit_generator(x, n_samples=n_samples)
|
||||||
|
self.assertEqual(hasattr(clf, 'layers'), True)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'iterator no n_samples',
|
||||||
|
'generator': True,
|
||||||
|
'reset_n_samples': True,
|
||||||
|
'distribution': 'laplace'
|
||||||
|
},
|
||||||
|
{'testcase_name': 'invalid distribution',
|
||||||
|
'generator': True,
|
||||||
|
'reset_n_samples': True,
|
||||||
|
'distribution': 'not_valid'
|
||||||
|
},
|
||||||
|
])
|
||||||
|
def test_bad_fit(self, generator, reset_n_samples, distribution):
|
||||||
|
"""Tests fitting with invalid parameters, which should raise an error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generator: True to test with generator, False is iterator
|
||||||
|
reset_n_samples: True to reset the n_samples param to None prior to
|
||||||
|
passing it to fit
|
||||||
|
distribution: distribution to get noise from.
|
||||||
|
"""
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
loss = TestLoss(1, 1, 1)
|
||||||
|
optimizer = TestOptimizer()
|
||||||
|
n_classes = 2
|
||||||
|
input_dim = 5
|
||||||
|
epsilon = 1
|
||||||
|
batch_size = 1
|
||||||
|
n_samples = 10
|
||||||
|
_do_fit(
|
||||||
|
n_samples,
|
||||||
|
input_dim,
|
||||||
|
n_classes,
|
||||||
|
epsilon,
|
||||||
|
generator,
|
||||||
|
batch_size,
|
||||||
|
reset_n_samples,
|
||||||
|
optimizer,
|
||||||
|
loss,
|
||||||
|
distribution
|
||||||
|
)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'None class_weights',
|
||||||
|
'class_weights': None,
|
||||||
|
'class_counts': None,
|
||||||
|
'num_classes': None,
|
||||||
|
'result': 1},
|
||||||
|
{'testcase_name': 'class weights array',
|
||||||
|
'class_weights': [1, 1],
|
||||||
|
'class_counts': [1, 1],
|
||||||
|
'num_classes': 2,
|
||||||
|
'result': [1, 1]},
|
||||||
|
{'testcase_name': 'class weights balanced',
|
||||||
|
'class_weights': 'balanced',
|
||||||
|
'class_counts': [1, 1],
|
||||||
|
'num_classes': 2,
|
||||||
|
'result': [1, 1]},
|
||||||
|
])
|
||||||
|
def test_class_calculate(self,
|
||||||
|
class_weights,
|
||||||
|
class_counts,
|
||||||
|
num_classes,
|
||||||
|
result):
|
||||||
|
"""Tests the BOltonModel calculate_class_weights method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_weights: the class_weights to use
|
||||||
|
class_counts: count of number of samples for each class
|
||||||
|
num_classes: number of outputs neurons
|
||||||
|
result: expected result
|
||||||
|
"""
|
||||||
|
clf = models.BoltOnModel(1, 1)
|
||||||
|
expected = clf.calculate_class_weights(class_weights,
|
||||||
|
class_counts,
|
||||||
|
num_classes)
|
||||||
|
|
||||||
|
if hasattr(expected, 'numpy'):
|
||||||
|
expected = expected.numpy()
|
||||||
|
self.assertAllEqual(
|
||||||
|
expected,
|
||||||
|
result
|
||||||
|
)
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'class weight not valid str',
|
||||||
|
'class_weights': 'not_valid',
|
||||||
|
'class_counts': 1,
|
||||||
|
'num_classes': 1,
|
||||||
|
'err_msg': 'Detected string class_weights with value: not_valid'},
|
||||||
|
{'testcase_name': 'no class counts',
|
||||||
|
'class_weights': 'balanced',
|
||||||
|
'class_counts': None,
|
||||||
|
'num_classes': 1,
|
||||||
|
'err_msg': 'Class counts must be provided if '
|
||||||
|
'using class_weights=balanced'},
|
||||||
|
{'testcase_name': 'no num classes',
|
||||||
|
'class_weights': 'balanced',
|
||||||
|
'class_counts': [1],
|
||||||
|
'num_classes': None,
|
||||||
|
'err_msg': 'num_classes must be provided if '
|
||||||
|
'using class_weights=balanced'},
|
||||||
|
{'testcase_name': 'class counts not array',
|
||||||
|
'class_weights': 'balanced',
|
||||||
|
'class_counts': 1,
|
||||||
|
'num_classes': None,
|
||||||
|
'err_msg': 'class counts must be a 1D array.'},
|
||||||
|
{'testcase_name': 'class counts array, no num classes',
|
||||||
|
'class_weights': [1],
|
||||||
|
'class_counts': None,
|
||||||
|
'num_classes': None,
|
||||||
|
'err_msg': 'You must pass a value for num_classes if '
|
||||||
|
'creating an array of class_weights'},
|
||||||
|
{'testcase_name': 'class counts array, improper shape',
|
||||||
|
'class_weights': [[1], [1]],
|
||||||
|
'class_counts': None,
|
||||||
|
'num_classes': 2,
|
||||||
|
'err_msg': 'Detected class_weights shape'},
|
||||||
|
{'testcase_name': 'class counts array, wrong number classes',
|
||||||
|
'class_weights': [1, 1, 1],
|
||||||
|
'class_counts': None,
|
||||||
|
'num_classes': 2,
|
||||||
|
'err_msg': 'Detected array length:'},
|
||||||
|
])
|
||||||
|
|
||||||
|
def test_class_errors(self,
|
||||||
|
class_weights,
|
||||||
|
class_counts,
|
||||||
|
num_classes,
|
||||||
|
err_msg):
|
||||||
|
"""Tests the BOltonModel calculate_class_weights method.
|
||||||
|
|
||||||
|
This test passes invalid params which should raise the expected errors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_weights: the class_weights to use.
|
||||||
|
class_counts: count of number of samples for each class.
|
||||||
|
num_classes: number of outputs neurons.
|
||||||
|
err_msg: The expected error message.
|
||||||
|
"""
|
||||||
|
clf = models.BoltOnModel(1, 1)
|
||||||
|
with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method
|
||||||
|
clf.calculate_class_weights(class_weights,
|
||||||
|
class_counts,
|
||||||
|
num_classes)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
390
privacy/bolt_on/optimizers.py
Normal file
390
privacy/bolt_on/optimizers.py
Normal file
|
@ -0,0 +1,390 @@
|
||||||
|
# Copyright 2019, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""BoltOn Optimizer for Bolt-on method."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from privacy.bolt_on.losses import StrongConvexMixin
|
||||||
|
|
||||||
|
_accepted_distributions = ['laplace'] # implemented distributions for noising
|
||||||
|
|
||||||
|
|
||||||
|
class GammaBetaDecreasingStep(
|
||||||
|
optimizer_v2.learning_rate_schedule.LearningRateSchedule):
|
||||||
|
"""Computes LR as minimum of 1/beta and 1/(gamma * step) at each step.
|
||||||
|
|
||||||
|
This is a required step for privacy guarantees.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.is_init = False
|
||||||
|
self.beta = None
|
||||||
|
self.gamma = None
|
||||||
|
|
||||||
|
def __call__(self, step):
|
||||||
|
"""Computes and returns the learning rate.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
step: the current iteration number
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
decayed learning rate to minimum of 1/beta and 1/(gamma * step) as per
|
||||||
|
the BoltOn privacy requirements.
|
||||||
|
"""
|
||||||
|
if not self.is_init:
|
||||||
|
raise AttributeError('Please initialize the {0} Learning Rate Scheduler.'
|
||||||
|
'This is performed automatically by using the '
|
||||||
|
'{1} as a context manager, '
|
||||||
|
'as desired'.format(self.__class__.__name__,
|
||||||
|
BoltOn.__class__.__name__
|
||||||
|
)
|
||||||
|
)
|
||||||
|
dtype = self.beta.dtype
|
||||||
|
one = tf.constant(1, dtype)
|
||||||
|
return tf.math.minimum(tf.math.reduce_min(one/self.beta),
|
||||||
|
one/(self.gamma*math_ops.cast(step, dtype))
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
"""Return config to setup the learning rate scheduler."""
|
||||||
|
return {'beta': self.beta, 'gamma': self.gamma}
|
||||||
|
|
||||||
|
def initialize(self, beta, gamma):
|
||||||
|
"""Setups scheduler with beta and gamma values from the loss function.
|
||||||
|
|
||||||
|
Meant to be used with .fit as the loss params may depend on values passed to
|
||||||
|
fit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
beta: Smoothness value. See StrongConvexMixin
|
||||||
|
gamma: Strong Convexity parameter. See StrongConvexMixin.
|
||||||
|
"""
|
||||||
|
self.is_init = True
|
||||||
|
self.beta = beta
|
||||||
|
self.gamma = gamma
|
||||||
|
|
||||||
|
def de_initialize(self):
|
||||||
|
"""De initialize post fit, as another fit call may use other parameters."""
|
||||||
|
self.is_init = False
|
||||||
|
self.beta = None
|
||||||
|
self.gamma = None
|
||||||
|
|
||||||
|
|
||||||
|
class BoltOn(optimizer_v2.OptimizerV2):
|
||||||
|
"""Wrap another tf optimizer with BoltOn privacy protocol.
|
||||||
|
|
||||||
|
BoltOn optimizer wraps another tf optimizer to be used
|
||||||
|
as the visible optimizer to the tf model. No matter the optimizer
|
||||||
|
passed, "BoltOn" enables the bolt-on model to control the learning rate
|
||||||
|
based on the strongly convex loss.
|
||||||
|
|
||||||
|
To use the BoltOn method, you must:
|
||||||
|
1. instantiate it with an instantiated tf optimizer and StrongConvexLoss.
|
||||||
|
2. use it as a context manager around your .fit method internals.
|
||||||
|
|
||||||
|
This can be accomplished by the following:
|
||||||
|
optimizer = tf.optimizers.SGD()
|
||||||
|
loss = privacy.bolt_on.losses.StrongConvexBinaryCrossentropy()
|
||||||
|
bolton = BoltOn(optimizer, loss)
|
||||||
|
with bolton(*args) as _:
|
||||||
|
model.fit()
|
||||||
|
The args required for the context manager can be found in the __call__
|
||||||
|
method.
|
||||||
|
|
||||||
|
For more details on the strong convexity requirements, see:
|
||||||
|
Bolt-on Differential Privacy for Scalable Stochastic Gradient
|
||||||
|
Descent-based Analytics by Xi Wu et. al.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, # pylint: disable=super-init-not-called
|
||||||
|
optimizer,
|
||||||
|
loss,
|
||||||
|
dtype=tf.float32,
|
||||||
|
):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer: Optimizer_v2 or subclass to be used as the optimizer
|
||||||
|
(wrapped).
|
||||||
|
loss: StrongConvexLoss function that the model is being compiled with.
|
||||||
|
dtype: dtype
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not isinstance(loss, StrongConvexMixin):
|
||||||
|
raise ValueError('loss function must be a Strongly Convex and therefore '
|
||||||
|
'extend the StrongConvexMixin.')
|
||||||
|
self._private_attributes = ['_internal_optimizer',
|
||||||
|
'dtype',
|
||||||
|
'noise_distribution',
|
||||||
|
'epsilon',
|
||||||
|
'loss',
|
||||||
|
'class_weights',
|
||||||
|
'input_dim',
|
||||||
|
'n_samples',
|
||||||
|
'layers',
|
||||||
|
'batch_size',
|
||||||
|
'_is_init'
|
||||||
|
]
|
||||||
|
self._internal_optimizer = optimizer
|
||||||
|
self.learning_rate = GammaBetaDecreasingStep() # use the BoltOn Learning
|
||||||
|
# rate scheduler, as required for privacy guarantees. This will still need
|
||||||
|
# to get values from the loss function near the time that .fit is called
|
||||||
|
# on the model (when this optimizer will be called as a context manager)
|
||||||
|
self.dtype = dtype
|
||||||
|
self.loss = loss
|
||||||
|
self._is_init = False
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
|
||||||
|
return self._internal_optimizer.get_config()
|
||||||
|
|
||||||
|
def project_weights_to_r(self, force=False):
|
||||||
|
"""Normalize the weights to the R-ball.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force: True to normalize regardless of previous weight values.
|
||||||
|
False to check if weights > R-ball and only normalize then.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If not called from inside this optimizer context.
|
||||||
|
"""
|
||||||
|
if not self._is_init:
|
||||||
|
raise Exception('This method must be called from within the optimizer\'s '
|
||||||
|
'context.')
|
||||||
|
radius = self.loss.radius()
|
||||||
|
for layer in self.layers:
|
||||||
|
weight_norm = tf.norm(layer.kernel, axis=0)
|
||||||
|
if force:
|
||||||
|
layer.kernel = layer.kernel / (weight_norm / radius)
|
||||||
|
else:
|
||||||
|
layer.kernel = tf.cond(
|
||||||
|
tf.reduce_sum(tf.cast(weight_norm > radius, dtype=self.dtype)) > 0,
|
||||||
|
lambda k=layer.kernel, w=weight_norm, r=radius: k / (w / r), # pylint: disable=cell-var-from-loop
|
||||||
|
lambda k=layer.kernel: k # pylint: disable=cell-var-from-loop
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_noise(self, input_dim, output_dim):
|
||||||
|
"""Sample noise to be added to weights for privacy guarantee.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dim: the input dimensionality for the weights
|
||||||
|
output_dim: the output dimensionality for the weights
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Noise in shape of layer's weights to be added to the weights.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If not called from inside this optimizer's context.
|
||||||
|
"""
|
||||||
|
if not self._is_init:
|
||||||
|
raise Exception('This method must be called from within the optimizer\'s '
|
||||||
|
'context.')
|
||||||
|
loss = self.loss
|
||||||
|
distribution = self.noise_distribution.lower()
|
||||||
|
if distribution == _accepted_distributions[0]: # laplace
|
||||||
|
per_class_epsilon = self.epsilon / (output_dim)
|
||||||
|
l2_sensitivity = (2 *
|
||||||
|
loss.lipchitz_constant(self.class_weights)) / \
|
||||||
|
(loss.gamma() * self.n_samples * self.batch_size)
|
||||||
|
unit_vector = tf.random.normal(shape=(input_dim, output_dim),
|
||||||
|
mean=0,
|
||||||
|
seed=1,
|
||||||
|
stddev=1.0,
|
||||||
|
dtype=self.dtype)
|
||||||
|
unit_vector = unit_vector / tf.math.sqrt(
|
||||||
|
tf.reduce_sum(tf.math.square(unit_vector), axis=0)
|
||||||
|
)
|
||||||
|
|
||||||
|
beta = l2_sensitivity / per_class_epsilon
|
||||||
|
alpha = input_dim # input_dim
|
||||||
|
gamma = tf.random.gamma([output_dim],
|
||||||
|
alpha,
|
||||||
|
beta=1 / beta,
|
||||||
|
seed=1,
|
||||||
|
dtype=self.dtype
|
||||||
|
)
|
||||||
|
return unit_vector * gamma
|
||||||
|
raise NotImplementedError('Noise distribution: {0} is not '
|
||||||
|
'a valid distribution'.format(distribution))
|
||||||
|
|
||||||
|
def from_config(self, *args, **kwargs): # pylint: disable=arguments-differ
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
|
||||||
|
return self._internal_optimizer.from_config(*args, **kwargs)
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
"""Get attr.
|
||||||
|
|
||||||
|
return _internal_optimizer off self instance, and everything else
|
||||||
|
from the _internal_optimizer instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Name of attribute to get from this or aggregate optimizer.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
attribute from BoltOn if specified to come from self, else
|
||||||
|
from _internal_optimizer.
|
||||||
|
"""
|
||||||
|
if name == '_private_attributes' or name in self._private_attributes:
|
||||||
|
return getattr(self, name)
|
||||||
|
optim = object.__getattribute__(self, '_internal_optimizer')
|
||||||
|
try:
|
||||||
|
return object.__getattribute__(optim, name)
|
||||||
|
except AttributeError:
|
||||||
|
raise AttributeError(
|
||||||
|
"Neither '{0}' nor '{1}' object has attribute '{2}'"
|
||||||
|
"".format(self.__class__.__name__,
|
||||||
|
self._internal_optimizer.__class__.__name__,
|
||||||
|
name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __setattr__(self, key, value):
|
||||||
|
"""Set attribute to self instance if its the internal optimizer.
|
||||||
|
|
||||||
|
Reroute everything else to the _internal_optimizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: attribute name
|
||||||
|
value: attribute value
|
||||||
|
"""
|
||||||
|
if key == '_private_attributes':
|
||||||
|
object.__setattr__(self, key, value)
|
||||||
|
elif key in self._private_attributes:
|
||||||
|
object.__setattr__(self, key, value)
|
||||||
|
else:
|
||||||
|
setattr(self._internal_optimizer, key, value)
|
||||||
|
|
||||||
|
def _resource_apply_dense(self, *args, **kwargs): # pylint: disable=arguments-differ
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
|
||||||
|
return self._internal_optimizer._resource_apply_dense(*args, **kwargs) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
def _resource_apply_sparse(self, *args, **kwargs): # pylint: disable=arguments-differ
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
|
||||||
|
return self._internal_optimizer._resource_apply_sparse(*args, **kwargs) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
def get_updates(self, loss, params):
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
|
||||||
|
out = self._internal_optimizer.get_updates(loss, params)
|
||||||
|
self.project_weights_to_r()
|
||||||
|
return out
|
||||||
|
|
||||||
|
def apply_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
|
||||||
|
out = self._internal_optimizer.apply_gradients(*args, **kwargs)
|
||||||
|
self.project_weights_to_r()
|
||||||
|
return out
|
||||||
|
|
||||||
|
def minimize(self, *args, **kwargs): # pylint: disable=arguments-differ
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
|
||||||
|
out = self._internal_optimizer.minimize(*args, **kwargs)
|
||||||
|
self.project_weights_to_r()
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _compute_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ,protected-access
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
|
||||||
|
return self._internal_optimizer._compute_gradients(*args, **kwargs) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
def get_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ
|
||||||
|
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
|
||||||
|
return self._internal_optimizer.get_gradients(*args, **kwargs)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""Context manager call at the beginning of with statement.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self, to be used in context manager
|
||||||
|
"""
|
||||||
|
self._is_init = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __call__(self,
|
||||||
|
noise_distribution,
|
||||||
|
epsilon,
|
||||||
|
layers,
|
||||||
|
class_weights,
|
||||||
|
n_samples,
|
||||||
|
batch_size
|
||||||
|
):
|
||||||
|
"""Accepts required values for bolton method from context entry point.
|
||||||
|
|
||||||
|
Stores them on the optimizer for use throughout fitting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
noise_distribution: the noise distribution to pick.
|
||||||
|
see _accepted_distributions and get_noise for possible values.
|
||||||
|
epsilon: privacy parameter. Lower gives more privacy but less utility.
|
||||||
|
layers: list of Keras/Tensorflow layers. Can be found as model.layers
|
||||||
|
class_weights: class_weights used, which may either be a scalar or 1D
|
||||||
|
tensor with dim == n_classes.
|
||||||
|
n_samples: number of rows/individual samples in the training set
|
||||||
|
batch_size: batch size used.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
self, to be used by the __enter__ method for context.
|
||||||
|
"""
|
||||||
|
if epsilon <= 0:
|
||||||
|
raise ValueError('Detected epsilon: {0}. '
|
||||||
|
'Valid range is 0 < epsilon <inf'.format(epsilon))
|
||||||
|
if noise_distribution not in _accepted_distributions:
|
||||||
|
raise ValueError('Detected noise distribution: {0} not one of: {1} valid'
|
||||||
|
'distributions'.format(noise_distribution,
|
||||||
|
_accepted_distributions))
|
||||||
|
self.noise_distribution = noise_distribution
|
||||||
|
self.learning_rate.initialize(self.loss.beta(class_weights),
|
||||||
|
self.loss.gamma()
|
||||||
|
)
|
||||||
|
self.epsilon = tf.constant(epsilon, dtype=self.dtype)
|
||||||
|
self.class_weights = tf.constant(class_weights, dtype=self.dtype)
|
||||||
|
self.n_samples = tf.constant(n_samples, dtype=self.dtype)
|
||||||
|
self.layers = layers
|
||||||
|
self.batch_size = tf.constant(batch_size, dtype=self.dtype)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
"""Exit call from with statement.
|
||||||
|
|
||||||
|
Used to:
|
||||||
|
1.reset the model and fit parameters passed to the optimizer
|
||||||
|
to enable the BoltOn Privacy guarantees. These are reset to ensure
|
||||||
|
that any future calls to fit with the same instance of the optimizer
|
||||||
|
will properly error out.
|
||||||
|
|
||||||
|
2.call post-fit methods normalizing/projecting the model weights and
|
||||||
|
adding noise to the weights.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*args: encompasses the type, value, and traceback values which are unused.
|
||||||
|
"""
|
||||||
|
self.project_weights_to_r(True)
|
||||||
|
for layer in self.layers:
|
||||||
|
input_dim = layer.kernel.shape[0]
|
||||||
|
output_dim = layer.units
|
||||||
|
noise = self.get_noise(input_dim,
|
||||||
|
output_dim,
|
||||||
|
)
|
||||||
|
layer.kernel = tf.math.add(layer.kernel, noise)
|
||||||
|
self.noise_distribution = None
|
||||||
|
self.learning_rate.de_initialize()
|
||||||
|
self.epsilon = -1
|
||||||
|
self.batch_size = -1
|
||||||
|
self.class_weights = None
|
||||||
|
self.n_samples = None
|
||||||
|
self.input_dim = None
|
||||||
|
self.layers = None
|
||||||
|
self._is_init = False
|
579
privacy/bolt_on/optimizers_test.py
Normal file
579
privacy/bolt_on/optimizers_test.py
Normal file
|
@ -0,0 +1,579 @@
|
||||||
|
# Copyright 2019, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Unit testing for optimizers."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.python import ops as _ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.keras import keras_parameterized
|
||||||
|
from tensorflow.python.keras import losses
|
||||||
|
from tensorflow.python.keras.initializers import constant
|
||||||
|
from tensorflow.python.keras.models import Model
|
||||||
|
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
|
||||||
|
from tensorflow.python.keras.regularizers import L1L2
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
from privacy.bolt_on import optimizers as opt
|
||||||
|
from privacy.bolt_on.losses import StrongConvexMixin
|
||||||
|
|
||||||
|
|
||||||
|
class TestModel(Model): # pylint: disable=abstract-method
|
||||||
|
"""BoltOn episilon-delta model.
|
||||||
|
|
||||||
|
Uses 4 key steps to achieve privacy guarantees:
|
||||||
|
1. Adds noise to weights after training (output perturbation).
|
||||||
|
2. Projects weights to R after each batch
|
||||||
|
3. Limits learning rate
|
||||||
|
4. Use a strongly convex loss function (see compile)
|
||||||
|
|
||||||
|
For more details on the strong convexity requirements, see:
|
||||||
|
Bolt-on Differential Privacy for Scalable Stochastic Gradient
|
||||||
|
Descent-based Analytics by Xi Wu et. al.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, n_outputs=2, input_shape=(16,), init_value=2):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_outputs: number of output neurons
|
||||||
|
input_shape:
|
||||||
|
init_value:
|
||||||
|
"""
|
||||||
|
super(TestModel, self).__init__(name='bolton', dynamic=False)
|
||||||
|
self.n_outputs = n_outputs
|
||||||
|
self.layer_input_shape = input_shape
|
||||||
|
self.output_layer = tf.keras.layers.Dense(
|
||||||
|
self.n_outputs,
|
||||||
|
input_shape=self.layer_input_shape,
|
||||||
|
kernel_regularizer=L1L2(l2=1),
|
||||||
|
kernel_initializer=constant(init_value),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoss(losses.Loss, StrongConvexMixin):
|
||||||
|
"""Test loss function for testing BoltOn model."""
|
||||||
|
|
||||||
|
def __init__(self, reg_lambda, c_arg, radius_constant, name='test'):
|
||||||
|
super(TestLoss, self).__init__(name=name)
|
||||||
|
self.reg_lambda = reg_lambda
|
||||||
|
self.C = c_arg # pylint: disable=invalid-name
|
||||||
|
self.radius_constant = radius_constant
|
||||||
|
|
||||||
|
def radius(self):
|
||||||
|
"""Radius, R, of the hypothesis space W.
|
||||||
|
|
||||||
|
W is a convex set that forms the hypothesis space.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a tensor
|
||||||
|
"""
|
||||||
|
return _ops.convert_to_tensor_v2(self.radius_constant, dtype=tf.float32)
|
||||||
|
|
||||||
|
def gamma(self):
|
||||||
|
"""Returns strongly convex parameter, gamma."""
|
||||||
|
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
|
||||||
|
|
||||||
|
def beta(self, class_weight): # pylint: disable=unused-argument
|
||||||
|
"""Smoothness, beta.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_weight: the class weights as scalar or 1d tensor, where its
|
||||||
|
dimensionality is equal to the number of outputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Beta
|
||||||
|
"""
|
||||||
|
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
|
||||||
|
|
||||||
|
def lipchitz_constant(self, class_weight): # pylint: disable=unused-argument
|
||||||
|
"""Lipchitz constant, L.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_weight: class weights used
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
constant L
|
||||||
|
"""
|
||||||
|
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
|
||||||
|
|
||||||
|
def call(self, y_true, y_pred):
|
||||||
|
"""Loss function that is minimized at the mean of the input points."""
|
||||||
|
return 0.5 * tf.reduce_sum(
|
||||||
|
tf.math.squared_difference(y_true, y_pred),
|
||||||
|
axis=1
|
||||||
|
)
|
||||||
|
|
||||||
|
def max_class_weight(self, class_weight, dtype=tf.float32):
|
||||||
|
"""the maximum weighting in class weights (max value) as a scalar tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
class_weight: class weights used
|
||||||
|
dtype: the data type for tensor conversions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
maximum class weighting as tensor scalar
|
||||||
|
"""
|
||||||
|
if class_weight is None:
|
||||||
|
return 1
|
||||||
|
raise NotImplementedError('')
|
||||||
|
|
||||||
|
def kernel_regularizer(self):
|
||||||
|
"""Returns the kernel_regularizer to be used.
|
||||||
|
|
||||||
|
Any subclass should override this method if they want a kernel_regularizer
|
||||||
|
(if required for the loss function to be StronglyConvex.
|
||||||
|
"""
|
||||||
|
return L1L2(l2=self.reg_lambda)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOptimizer(OptimizerV2):
|
||||||
|
"""Optimizer used for testing the BoltOn optimizer."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super(TestOptimizer, self).__init__('test')
|
||||||
|
self.not_private = 'test'
|
||||||
|
self.iterations = tf.constant(1, dtype=tf.float32)
|
||||||
|
self._iterations = tf.constant(1, dtype=tf.float32)
|
||||||
|
|
||||||
|
def _compute_gradients(self, loss, var_list, grad_loss=None):
|
||||||
|
return 'test'
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return 'test'
|
||||||
|
|
||||||
|
def from_config(self, config, custom_objects=None):
|
||||||
|
return 'test'
|
||||||
|
|
||||||
|
def _create_slots(self):
|
||||||
|
return 'test'
|
||||||
|
|
||||||
|
def _resource_apply_dense(self, grad, handle):
|
||||||
|
return 'test'
|
||||||
|
|
||||||
|
def _resource_apply_sparse(self, grad, handle, indices):
|
||||||
|
return 'test'
|
||||||
|
|
||||||
|
def get_updates(self, loss, params):
|
||||||
|
return 'test'
|
||||||
|
|
||||||
|
def apply_gradients(self, grads_and_vars, name=None):
|
||||||
|
return 'test'
|
||||||
|
|
||||||
|
def minimize(self, loss, var_list, grad_loss=None, name=None):
|
||||||
|
return 'test'
|
||||||
|
|
||||||
|
def get_gradients(self, loss, params):
|
||||||
|
return 'test'
|
||||||
|
|
||||||
|
def limit_learning_rate(self):
|
||||||
|
return 'test'
|
||||||
|
|
||||||
|
|
||||||
|
class BoltonOptimizerTest(keras_parameterized.TestCase):
|
||||||
|
"""BoltOn Optimizer tests."""
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'getattr',
|
||||||
|
'fn': '__getattr__',
|
||||||
|
'args': ['dtype'],
|
||||||
|
'result': tf.float32,
|
||||||
|
'test_attr': None},
|
||||||
|
{'testcase_name': 'project_weights_to_r',
|
||||||
|
'fn': 'project_weights_to_r',
|
||||||
|
'args': ['dtype'],
|
||||||
|
'result': None,
|
||||||
|
'test_attr': ''},
|
||||||
|
])
|
||||||
|
|
||||||
|
def test_fn(self, fn, args, result, test_attr):
|
||||||
|
"""test that a fn of BoltOn optimizer is working as expected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: method of Optimizer to test
|
||||||
|
args: args to optimizer fn
|
||||||
|
result: the expected result
|
||||||
|
test_attr: None if the fn returns the test result. Otherwise, this is
|
||||||
|
the attribute of BoltOn to check against result with.
|
||||||
|
|
||||||
|
"""
|
||||||
|
tf.random.set_seed(1)
|
||||||
|
loss = TestLoss(1, 1, 1)
|
||||||
|
bolton = opt.BoltOn(TestOptimizer(), loss)
|
||||||
|
model = TestModel(1)
|
||||||
|
model.layers[0].kernel = \
|
||||||
|
model.layers[0].kernel_initializer((model.layer_input_shape[0],
|
||||||
|
model.n_outputs))
|
||||||
|
bolton._is_init = True # pylint: disable=protected-access
|
||||||
|
bolton.layers = model.layers
|
||||||
|
bolton.epsilon = 2
|
||||||
|
bolton.noise_distribution = 'laplace'
|
||||||
|
bolton.n_outputs = 1
|
||||||
|
bolton.n_samples = 1
|
||||||
|
res = getattr(bolton, fn, None)(*args)
|
||||||
|
if test_attr is not None:
|
||||||
|
res = getattr(bolton, test_attr, None)
|
||||||
|
if hasattr(res, 'numpy') and hasattr(result, 'numpy'): # both tensors/not
|
||||||
|
res = res.numpy()
|
||||||
|
result = result.numpy()
|
||||||
|
self.assertEqual(res, result)
|
||||||
|
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': '1 value project to r=1',
|
||||||
|
'r': 1,
|
||||||
|
'init_value': 2,
|
||||||
|
'shape': (1,),
|
||||||
|
'n_out': 1,
|
||||||
|
'result': [[1]]},
|
||||||
|
{'testcase_name': '2 value project to r=1',
|
||||||
|
'r': 1,
|
||||||
|
'init_value': 2,
|
||||||
|
'shape': (2,),
|
||||||
|
'n_out': 1,
|
||||||
|
'result': [[0.707107], [0.707107]]},
|
||||||
|
{'testcase_name': '1 value project to r=2',
|
||||||
|
'r': 2,
|
||||||
|
'init_value': 3,
|
||||||
|
'shape': (1,),
|
||||||
|
'n_out': 1,
|
||||||
|
'result': [[2]]},
|
||||||
|
{'testcase_name': 'no project',
|
||||||
|
'r': 2,
|
||||||
|
'init_value': 1,
|
||||||
|
'shape': (1,),
|
||||||
|
'n_out': 1,
|
||||||
|
'result': [[1]]},
|
||||||
|
])
|
||||||
|
def test_project(self, r, shape, n_out, init_value, result):
|
||||||
|
"""test that a fn of BoltOn optimizer is working as expected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
r: Radius value for StrongConvex loss function.
|
||||||
|
shape: input_dimensionality
|
||||||
|
n_out: output dimensionality
|
||||||
|
init_value: the initial value for 'constant' kernel initializer
|
||||||
|
result: the expected output after projection.
|
||||||
|
"""
|
||||||
|
tf.random.set_seed(1)
|
||||||
|
@tf.function
|
||||||
|
def project_fn(r):
|
||||||
|
loss = TestLoss(1, 1, r)
|
||||||
|
bolton = opt.BoltOn(TestOptimizer(), loss)
|
||||||
|
model = TestModel(n_out, shape, init_value)
|
||||||
|
model.compile(bolton, loss)
|
||||||
|
model.layers[0].kernel = \
|
||||||
|
model.layers[0].kernel_initializer((model.layer_input_shape[0],
|
||||||
|
model.n_outputs))
|
||||||
|
bolton._is_init = True # pylint: disable=protected-access
|
||||||
|
bolton.layers = model.layers
|
||||||
|
bolton.epsilon = 2
|
||||||
|
bolton.noise_distribution = 'laplace'
|
||||||
|
bolton.n_outputs = 1
|
||||||
|
bolton.n_samples = 1
|
||||||
|
bolton.project_weights_to_r()
|
||||||
|
return _ops.convert_to_tensor_v2(bolton.layers[0].kernel, tf.float32)
|
||||||
|
res = project_fn(r)
|
||||||
|
self.assertAllClose(res, result)
|
||||||
|
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'normal values',
|
||||||
|
'epsilon': 2,
|
||||||
|
'noise': 'laplace',
|
||||||
|
'class_weights': 1},
|
||||||
|
])
|
||||||
|
def test_context_manager(self, noise, epsilon, class_weights):
|
||||||
|
"""Tests the context manager functionality of the optimizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
noise: noise distribution to pick
|
||||||
|
epsilon: epsilon privacy parameter to use
|
||||||
|
class_weights: class_weights to use
|
||||||
|
"""
|
||||||
|
@tf.function
|
||||||
|
def test_run():
|
||||||
|
loss = TestLoss(1, 1, 1)
|
||||||
|
bolton = opt.BoltOn(TestOptimizer(), loss)
|
||||||
|
model = TestModel(1, (1,), 1)
|
||||||
|
model.compile(bolton, loss)
|
||||||
|
model.layers[0].kernel = \
|
||||||
|
model.layers[0].kernel_initializer((model.layer_input_shape[0],
|
||||||
|
model.n_outputs))
|
||||||
|
with bolton(noise, epsilon, model.layers, class_weights, 1, 1) as _:
|
||||||
|
pass
|
||||||
|
return _ops.convert_to_tensor_v2(bolton.epsilon, dtype=tf.float32)
|
||||||
|
epsilon = test_run()
|
||||||
|
self.assertEqual(epsilon.numpy(), -1)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'invalid noise',
|
||||||
|
'epsilon': 1,
|
||||||
|
'noise': 'not_valid',
|
||||||
|
'err_msg': 'Detected noise distribution: not_valid not one of:'},
|
||||||
|
{'testcase_name': 'invalid epsilon',
|
||||||
|
'epsilon': -1,
|
||||||
|
'noise': 'laplace',
|
||||||
|
'err_msg': 'Detected epsilon: -1. Valid range is 0 < epsilon <inf'},
|
||||||
|
])
|
||||||
|
def test_context_domains(self, noise, epsilon, err_msg):
|
||||||
|
"""Tests the context domains.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
noise: noise distribution to pick
|
||||||
|
epsilon: epsilon privacy parameter to use
|
||||||
|
err_msg: the expected error message
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
@tf.function
|
||||||
|
def test_run(noise, epsilon):
|
||||||
|
loss = TestLoss(1, 1, 1)
|
||||||
|
bolton = opt.BoltOn(TestOptimizer(), loss)
|
||||||
|
model = TestModel(1, (1,), 1)
|
||||||
|
model.compile(bolton, loss)
|
||||||
|
model.layers[0].kernel = \
|
||||||
|
model.layers[0].kernel_initializer((model.layer_input_shape[0],
|
||||||
|
model.n_outputs))
|
||||||
|
with bolton(noise, epsilon, model.layers, 1, 1, 1) as _:
|
||||||
|
pass
|
||||||
|
with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method
|
||||||
|
test_run(noise, epsilon)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'fn: get_noise',
|
||||||
|
'fn': 'get_noise',
|
||||||
|
'args': [1, 1],
|
||||||
|
'err_msg': 'ust be called from within the optimizer\'s context'},
|
||||||
|
])
|
||||||
|
def test_not_in_context(self, fn, args, err_msg):
|
||||||
|
"""Tests that the expected functions raise errors when not in context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: the function to test
|
||||||
|
args: the arguments for said function
|
||||||
|
err_msg: expected error message
|
||||||
|
"""
|
||||||
|
@tf.function
|
||||||
|
def test_run(fn, args):
|
||||||
|
loss = TestLoss(1, 1, 1)
|
||||||
|
bolton = opt.BoltOn(TestOptimizer(), loss)
|
||||||
|
model = TestModel(1, (1,), 1)
|
||||||
|
model.compile(bolton, loss)
|
||||||
|
model.layers[0].kernel = \
|
||||||
|
model.layers[0].kernel_initializer((model.layer_input_shape[0],
|
||||||
|
model.n_outputs))
|
||||||
|
getattr(bolton, fn)(*args)
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(Exception, err_msg): # pylint: disable=deprecated-method
|
||||||
|
test_run(fn, args)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'fn: get_updates',
|
||||||
|
'fn': 'get_updates',
|
||||||
|
'args': [0, 0]},
|
||||||
|
{'testcase_name': 'fn: get_config',
|
||||||
|
'fn': 'get_config',
|
||||||
|
'args': []},
|
||||||
|
{'testcase_name': 'fn: from_config',
|
||||||
|
'fn': 'from_config',
|
||||||
|
'args': [0]},
|
||||||
|
{'testcase_name': 'fn: _resource_apply_dense',
|
||||||
|
'fn': '_resource_apply_dense',
|
||||||
|
'args': [1, 1]},
|
||||||
|
{'testcase_name': 'fn: _resource_apply_sparse',
|
||||||
|
'fn': '_resource_apply_sparse',
|
||||||
|
'args': [1, 1, 1]},
|
||||||
|
{'testcase_name': 'fn: apply_gradients',
|
||||||
|
'fn': 'apply_gradients',
|
||||||
|
'args': [1]},
|
||||||
|
{'testcase_name': 'fn: minimize',
|
||||||
|
'fn': 'minimize',
|
||||||
|
'args': [1, 1]},
|
||||||
|
{'testcase_name': 'fn: _compute_gradients',
|
||||||
|
'fn': '_compute_gradients',
|
||||||
|
'args': [1, 1]},
|
||||||
|
{'testcase_name': 'fn: get_gradients',
|
||||||
|
'fn': 'get_gradients',
|
||||||
|
'args': [1, 1]},
|
||||||
|
])
|
||||||
|
def test_rerouted_function(self, fn, args):
|
||||||
|
"""Tests rerouted function.
|
||||||
|
|
||||||
|
Tests that a method of the internal optimizer is correctly routed from
|
||||||
|
the BoltOn instance to the internal optimizer instance (TestOptimizer,
|
||||||
|
here).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: fn to test
|
||||||
|
args: arguments to that fn
|
||||||
|
"""
|
||||||
|
loss = TestLoss(1, 1, 1)
|
||||||
|
optimizer = TestOptimizer()
|
||||||
|
bolton = opt.BoltOn(optimizer, loss)
|
||||||
|
model = TestModel(3)
|
||||||
|
model.compile(optimizer, loss)
|
||||||
|
model.layers[0].kernel = \
|
||||||
|
model.layers[0].kernel_initializer((model.layer_input_shape[0],
|
||||||
|
model.n_outputs))
|
||||||
|
model.layers[0].kernel = \
|
||||||
|
model.layers[0].kernel_initializer((model.layer_input_shape[0],
|
||||||
|
model.n_outputs))
|
||||||
|
bolton._is_init = True # pylint: disable=protected-access
|
||||||
|
bolton.layers = model.layers
|
||||||
|
bolton.epsilon = 2
|
||||||
|
bolton.noise_distribution = 'laplace'
|
||||||
|
bolton.n_outputs = 1
|
||||||
|
bolton.n_samples = 1
|
||||||
|
self.assertEqual(
|
||||||
|
getattr(bolton, fn, lambda: 'fn not found')(*args),
|
||||||
|
'test'
|
||||||
|
)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'fn: project_weights_to_r',
|
||||||
|
'fn': 'project_weights_to_r',
|
||||||
|
'args': []},
|
||||||
|
{'testcase_name': 'fn: get_noise',
|
||||||
|
'fn': 'get_noise',
|
||||||
|
'args': [1, 1]},
|
||||||
|
])
|
||||||
|
def test_not_reroute_fn(self, fn, args):
|
||||||
|
"""Test function is not rerouted.
|
||||||
|
|
||||||
|
Test that a fn that should not be rerouted to the internal optimizer is
|
||||||
|
in fact not rerouted.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: fn to test
|
||||||
|
args: arguments to that fn
|
||||||
|
"""
|
||||||
|
@tf.function
|
||||||
|
def test_run(fn, args):
|
||||||
|
loss = TestLoss(1, 1, 1)
|
||||||
|
bolton = opt.BoltOn(TestOptimizer(), loss)
|
||||||
|
model = TestModel(1, (1,), 1)
|
||||||
|
model.compile(bolton, loss)
|
||||||
|
model.layers[0].kernel = \
|
||||||
|
model.layers[0].kernel_initializer((model.layer_input_shape[0],
|
||||||
|
model.n_outputs))
|
||||||
|
bolton._is_init = True # pylint: disable=protected-access
|
||||||
|
bolton.noise_distribution = 'laplace'
|
||||||
|
bolton.epsilon = 1
|
||||||
|
bolton.layers = model.layers
|
||||||
|
bolton.class_weights = 1
|
||||||
|
bolton.n_samples = 1
|
||||||
|
bolton.batch_size = 1
|
||||||
|
bolton.n_outputs = 1
|
||||||
|
res = getattr(bolton, fn, lambda: 'test')(*args)
|
||||||
|
if res != 'test':
|
||||||
|
res = 1
|
||||||
|
else:
|
||||||
|
res = 0
|
||||||
|
return _ops.convert_to_tensor_v2(res, dtype=tf.float32)
|
||||||
|
self.assertNotEqual(test_run(fn, args), 0)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'attr: _iterations',
|
||||||
|
'attr': '_iterations'}
|
||||||
|
])
|
||||||
|
def test_reroute_attr(self, attr):
|
||||||
|
"""Test a function is rerouted.
|
||||||
|
|
||||||
|
Test that attribute of internal optimizer is correctly rerouted to the
|
||||||
|
internal optimizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attr: attribute to test
|
||||||
|
"""
|
||||||
|
loss = TestLoss(1, 1, 1)
|
||||||
|
internal_optimizer = TestOptimizer()
|
||||||
|
optimizer = opt.BoltOn(internal_optimizer, loss)
|
||||||
|
self.assertEqual(getattr(optimizer, attr),
|
||||||
|
getattr(internal_optimizer, attr))
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'attr does not exist',
|
||||||
|
'attr': '_not_valid'}
|
||||||
|
])
|
||||||
|
def test_attribute_error(self, attr):
|
||||||
|
"""Test rerouting of attributes.
|
||||||
|
|
||||||
|
Test that attribute of internal optimizer is correctly rerouted to the
|
||||||
|
internal optimizer
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attr: attribute to test
|
||||||
|
"""
|
||||||
|
loss = TestLoss(1, 1, 1)
|
||||||
|
internal_optimizer = TestOptimizer()
|
||||||
|
optimizer = opt.BoltOn(internal_optimizer, loss)
|
||||||
|
with self.assertRaises(AttributeError):
|
||||||
|
getattr(optimizer, attr)
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerTest(keras_parameterized.TestCase):
|
||||||
|
"""GammaBeta Scheduler tests."""
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'not in context',
|
||||||
|
'err_msg': 'Please initialize the GammaBetaDecreasingStep Learning Rate'
|
||||||
|
' Scheduler'
|
||||||
|
}
|
||||||
|
])
|
||||||
|
def test_bad_call(self, err_msg):
|
||||||
|
"""Test attribute of internal opt correctly rerouted to the internal opt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
err_msg: The expected error message from the scheduler bad call.
|
||||||
|
"""
|
||||||
|
scheduler = opt.GammaBetaDecreasingStep()
|
||||||
|
with self.assertRaisesRegexp(Exception, err_msg): # pylint: disable=deprecated-method
|
||||||
|
scheduler(1)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
{'testcase_name': 'step 1',
|
||||||
|
'step': 1,
|
||||||
|
'res': 0.5},
|
||||||
|
{'testcase_name': 'step 2',
|
||||||
|
'step': 2,
|
||||||
|
'res': 0.5},
|
||||||
|
{'testcase_name': 'step 3',
|
||||||
|
'step': 3,
|
||||||
|
'res': 0.333333333},
|
||||||
|
])
|
||||||
|
def test_call(self, step, res):
|
||||||
|
"""Test call.
|
||||||
|
|
||||||
|
Test that attribute of internal optimizer is correctly rerouted to the
|
||||||
|
internal optimizer
|
||||||
|
|
||||||
|
Args:
|
||||||
|
step: step number to 'GammaBetaDecreasingStep' 'Scheduler'.
|
||||||
|
res: expected result from call to 'GammaBetaDecreasingStep' 'Scheduler'.
|
||||||
|
"""
|
||||||
|
beta = _ops.convert_to_tensor_v2(2, dtype=tf.float32)
|
||||||
|
gamma = _ops.convert_to_tensor_v2(1, dtype=tf.float32)
|
||||||
|
scheduler = opt.GammaBetaDecreasingStep()
|
||||||
|
scheduler.initialize(beta, gamma)
|
||||||
|
step = _ops.convert_to_tensor_v2(step, dtype=tf.float32)
|
||||||
|
lr = scheduler(step)
|
||||||
|
self.assertAllClose(lr.numpy(), res)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
188
tutorials/bolton_tutorial.py
Normal file
188
tutorials/bolton_tutorial.py
Normal file
|
@ -0,0 +1,188 @@
|
||||||
|
# Copyright 2019, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Tutorial for bolt_on module, the model and the optimizer."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
import tensorflow as tf # pylint: disable=wrong-import-position
|
||||||
|
from privacy.bolt_on import losses # pylint: disable=wrong-import-position
|
||||||
|
from privacy.bolt_on import models # pylint: disable=wrong-import-position
|
||||||
|
from privacy.bolt_on.optimizers import BoltOn # pylint: disable=wrong-import-position
|
||||||
|
# -------
|
||||||
|
# First, we will create a binary classification dataset with a single output
|
||||||
|
# dimension. The samples for each label are repeated data points at different
|
||||||
|
# points in space.
|
||||||
|
# -------
|
||||||
|
# Parameters for dataset
|
||||||
|
n_samples = 10
|
||||||
|
input_dim = 2
|
||||||
|
n_outputs = 1
|
||||||
|
# Create binary classification dataset:
|
||||||
|
x_stack = [tf.constant(-1, tf.float32, (n_samples, input_dim)),
|
||||||
|
tf.constant(1, tf.float32, (n_samples, input_dim))]
|
||||||
|
y_stack = [tf.constant(0, tf.float32, (n_samples, 1)),
|
||||||
|
tf.constant(1, tf.float32, (n_samples, 1))]
|
||||||
|
x, y = tf.concat(x_stack, 0), tf.concat(y_stack, 0)
|
||||||
|
print(x.shape, y.shape)
|
||||||
|
generator = tf.data.Dataset.from_tensor_slices((x, y))
|
||||||
|
generator = generator.batch(10)
|
||||||
|
generator = generator.shuffle(10)
|
||||||
|
# -------
|
||||||
|
# First, we will explore using the pre - built BoltOnModel, which is a thin
|
||||||
|
# wrapper around a Keras Model using a single - layer neural network.
|
||||||
|
# It automatically uses the BoltOn Optimizer which encompasses all the logic
|
||||||
|
# required for the BoltOn Differential Privacy method.
|
||||||
|
# -------
|
||||||
|
bolt = models.BoltOnModel(n_outputs) # tell the model how many outputs we have.
|
||||||
|
# -------
|
||||||
|
# Now, we will pick our optimizer and Strongly Convex Loss function. The loss
|
||||||
|
# must extend from StrongConvexMixin and implement the associated methods.Some
|
||||||
|
# existing loss functions are pre - implemented in bolt_on.loss
|
||||||
|
# -------
|
||||||
|
optimizer = tf.optimizers.SGD()
|
||||||
|
reg_lambda = 1
|
||||||
|
C = 1
|
||||||
|
radius_constant = 1
|
||||||
|
loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
|
||||||
|
# -------
|
||||||
|
# For simplicity, we pick all parameters of the StrongConvexBinaryCrossentropy
|
||||||
|
# to be 1; these are all tunable and their impact can be read in losses.
|
||||||
|
# StrongConvexBinaryCrossentropy.We then compile the model with the chosen
|
||||||
|
# optimizer and loss, which will automatically wrap the chosen optimizer with
|
||||||
|
# the BoltOn Optimizer, ensuring the required components function as required
|
||||||
|
# for privacy guarantees.
|
||||||
|
# -------
|
||||||
|
bolt.compile(optimizer, loss)
|
||||||
|
# -------
|
||||||
|
# To fit the model, the optimizer will require additional information about
|
||||||
|
# the dataset and model.These parameters are:
|
||||||
|
# 1. the class_weights used
|
||||||
|
# 2. the number of samples in the dataset
|
||||||
|
# 3. the batch size which the model will try to infer, if possible. If not,
|
||||||
|
# you will be required to pass these explicitly to the fit method.
|
||||||
|
#
|
||||||
|
# As well, there are two privacy parameters than can be altered:
|
||||||
|
# 1. epsilon, a float
|
||||||
|
# 2. noise_distribution, a valid string indicating the distriution to use (must
|
||||||
|
# be implemented)
|
||||||
|
#
|
||||||
|
# The BoltOnModel offers a helper method,.calculate_class_weight to aid in
|
||||||
|
# class_weight calculation.
|
||||||
|
# required parameters
|
||||||
|
# -------
|
||||||
|
class_weight = None # default, use .calculate_class_weight for other values
|
||||||
|
batch_size = None # default, if it cannot be inferred, specify this
|
||||||
|
n_samples = None # default, if it cannot be iferred, specify this
|
||||||
|
# privacy parameters
|
||||||
|
epsilon = 2
|
||||||
|
noise_distribution = 'laplace'
|
||||||
|
|
||||||
|
bolt.fit(x,
|
||||||
|
y,
|
||||||
|
epsilon=epsilon,
|
||||||
|
class_weight=class_weight,
|
||||||
|
batch_size=batch_size,
|
||||||
|
n_samples=n_samples,
|
||||||
|
noise_distribution=noise_distribution,
|
||||||
|
epochs=2)
|
||||||
|
# -------
|
||||||
|
# We may also train a generator object, or try different optimizers and loss
|
||||||
|
# functions. Below, we will see that we must pass the number of samples as the
|
||||||
|
# fit method is unable to infer it for a generator.
|
||||||
|
# -------
|
||||||
|
optimizer2 = tf.optimizers.Adam()
|
||||||
|
bolt.compile(optimizer2, loss)
|
||||||
|
# required parameters
|
||||||
|
class_weight = None # default, use .calculate_class_weight for other values
|
||||||
|
batch_size = None # default, if it cannot be inferred, specify this
|
||||||
|
n_samples = None # default, if it cannot be iferred, specify this
|
||||||
|
# privacy parameters
|
||||||
|
epsilon = 2
|
||||||
|
noise_distribution = 'laplace'
|
||||||
|
try:
|
||||||
|
bolt.fit(generator,
|
||||||
|
epsilon=epsilon,
|
||||||
|
class_weight=class_weight,
|
||||||
|
batch_size=batch_size,
|
||||||
|
n_samples=n_samples,
|
||||||
|
noise_distribution=noise_distribution,
|
||||||
|
verbose=0)
|
||||||
|
except ValueError as e:
|
||||||
|
print(e)
|
||||||
|
# -------
|
||||||
|
# And now, re running with the parameter set.
|
||||||
|
# -------
|
||||||
|
n_samples = 20
|
||||||
|
bolt.fit(generator,
|
||||||
|
epsilon=epsilon,
|
||||||
|
class_weight=class_weight,
|
||||||
|
batch_size=batch_size,
|
||||||
|
n_samples=n_samples,
|
||||||
|
noise_distribution=noise_distribution,
|
||||||
|
verbose=0)
|
||||||
|
# -------
|
||||||
|
# You don't have to use the BoltOn model to use the BoltOn method.
|
||||||
|
# There are only a few requirements:
|
||||||
|
# 1. make sure any requirements from the loss are implemented in the model.
|
||||||
|
# 2. instantiate the optimizer and use it as a context around the fit operation.
|
||||||
|
# -------
|
||||||
|
# -------------------- Part 2, using the Optimizer
|
||||||
|
|
||||||
|
# -------
|
||||||
|
# Here, we create our own model and setup the BoltOn optimizer.
|
||||||
|
# -------
|
||||||
|
|
||||||
|
|
||||||
|
class TestModel(tf.keras.Model): # pylint: disable=abstract-method
|
||||||
|
|
||||||
|
def __init__(self, reg_layer, number_of_outputs=1):
|
||||||
|
super(TestModel, self).__init__(name='test')
|
||||||
|
self.output_layer = tf.keras.layers.Dense(number_of_outputs,
|
||||||
|
kernel_regularizer=reg_layer)
|
||||||
|
|
||||||
|
def call(self, inputs): # pylint: disable=arguments-differ
|
||||||
|
return self.output_layer(inputs)
|
||||||
|
|
||||||
|
|
||||||
|
optimizer = tf.optimizers.SGD()
|
||||||
|
loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
|
||||||
|
optimizer = BoltOn(optimizer, loss)
|
||||||
|
# -------
|
||||||
|
# Now, we instantiate our model and check for 1. Since our loss requires L2
|
||||||
|
# regularization over the kernel, we will pass it to the model.
|
||||||
|
# -------
|
||||||
|
n_outputs = 1 # parameter for model and optimizer context.
|
||||||
|
test_model = TestModel(loss.kernel_regularizer(), n_outputs)
|
||||||
|
test_model.compile(optimizer, loss)
|
||||||
|
# -------
|
||||||
|
# We comply with 2., and use the BoltOn Optimizer as a context around the fit
|
||||||
|
# method.
|
||||||
|
# -------
|
||||||
|
# parameters for context
|
||||||
|
noise_distribution = 'laplace'
|
||||||
|
epsilon = 2
|
||||||
|
class_weights = 1 # Previously, the fit method auto-detected the class_weights.
|
||||||
|
# Here, we need to pass the class_weights explicitly. 1 is the same as None.
|
||||||
|
n_samples = 20
|
||||||
|
batch_size = 5
|
||||||
|
|
||||||
|
with optimizer(
|
||||||
|
noise_distribution=noise_distribution,
|
||||||
|
epsilon=epsilon,
|
||||||
|
layers=test_model.layers,
|
||||||
|
class_weights=class_weights,
|
||||||
|
n_samples=n_samples,
|
||||||
|
batch_size=batch_size
|
||||||
|
) as _:
|
||||||
|
test_model.fit(x, y, batch_size=batch_size, epochs=2)
|
Loading…
Reference in a new issue