Merge pull request #53 from georgianpartners:master

PiperOrigin-RevId: 260990063
This commit is contained in:
A. Unique TensorFlower 2019-07-31 13:44:10 -07:00
commit 9fe5e91de4
10 changed files with 2815 additions and 0 deletions

View file

@ -41,3 +41,9 @@ else:
from privacy.optimizers.dp_optimizer import DPAdamOptimizer from privacy.optimizers.dp_optimizer import DPAdamOptimizer
from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer
from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer from privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer
from privacy.bolt_on.models import BoltOnModel
from privacy.bolt_on.optimizers import BoltOn
from privacy.bolt_on.losses import StrongConvexMixin
from privacy.bolt_on.losses import StrongConvexBinaryCrossentropy
from privacy.bolt_on.losses import StrongConvexHuber

57
privacy/bolt_on/README.md Normal file
View file

@ -0,0 +1,57 @@
# BoltOn Subpackage
This package contains source code for the BoltOn method, a particular
differential-privacy (DP) technique that uses output perturbations and
leverages additional assumptions to provide a new way of approaching the
privacy guarantees.
## BoltOn Description
This method uses 4 key steps to achieve privacy guarantees:
1. Adds noise to weights after training (output perturbation).
2. Projects weights to R, the radius of the hypothesis space,
after each batch. This value is configurable by the user.
3. Limits learning rate
4. Uses a strongly convex loss function (see compile)
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et al. at https://arxiv.org/pdf/1606.04722.pdf
## Why BoltOn?
The major difference for the BoltOn method is that it injects noise post model
convergence, rather than noising gradients or weights during training. This
approach requires some additional constraints listed in the Description.
Should the use-case and model satisfy these constraints, this is another
approach that can be trained to maximize utility while maintaining the privacy.
The paper describes in detail the advantages and disadvantages of this approach
and its results compared to some other methods, namely noising at each iteration
and no noising.
## Tutorials
This package has a tutorial that can be found in the root tutorials directory,
under `bolton_tutorial.py`.
## Contribution
This package was initially contributed by Georgian Partners with the hope of
growing the tensorflow/privacy library. There are several rich use cases for
delta-epsilon privacy in machine learning, some of which can be explored here:
https://medium.com/apache-mxnet/epsilon-differential-privacy-for-machine-learning-using-mxnet-a4270fe3865e
https://arxiv.org/pdf/1811.04911.pdf
## Contacts
In addition to the maintainers of tensorflow/privacy listed in the root
README.md, please feel free to contact members of Georgian Partners. In
particular,
* Georgian Partners(@georgianpartners)
* Ji Chao Zhang(@Jichaogp)
* Christopher Choquette(@cchoquette)
## Copyright
Copyright 2019 - Google LLC

View file

@ -0,0 +1,29 @@
# Copyright 2019, The TensorFlow Privacy Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BoltOn Method for privacy."""
import sys
from distutils.version import LooseVersion
import tensorflow as tf
if LooseVersion(tf.__version__) < LooseVersion("2.0.0"):
raise ImportError("Please upgrade your version "
"of tensorflow from: {0} to at least 2.0.0 to "
"use privacy/bolt_on".format(LooseVersion(tf.__version__)))
if hasattr(sys, "skip_tf_privacy_import"): # Useful for standalone scripts.
pass
else:
from privacy.bolt_on.models import BoltOnModel # pylint: disable=g-import-not-at-top
from privacy.bolt_on.optimizers import BoltOn # pylint: disable=g-import-not-at-top
from privacy.bolt_on.losses import StrongConvexHuber # pylint: disable=g-import-not-at-top
from privacy.bolt_on.losses import StrongConvexBinaryCrossentropy # pylint: disable=g-import-not-at-top

304
privacy/bolt_on/losses.py Normal file
View file

@ -0,0 +1,304 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loss functions for BoltOn method."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.framework import ops as _ops
from tensorflow.python.keras import losses
from tensorflow.python.keras.regularizers import L1L2
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.platform import tf_logging as logging
class StrongConvexMixin: # pylint: disable=old-style-class
"""Strong Convex Mixin base class.
Strong Convex Mixin base class for any loss function that will be used with
BoltOn model. Subclasses must be strongly convex and implement the
associated constants. They must also conform to the requirements of tf losses
(see super class).
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et. al.
"""
def radius(self):
"""Radius, R, of the hypothesis space W.
W is a convex set that forms the hypothesis space.
Returns:
R
"""
raise NotImplementedError("Radius not implemented for StrongConvex Loss"
"function: %s" % str(self.__class__.__name__))
def gamma(self):
"""Returns strongly convex parameter, gamma."""
raise NotImplementedError("Gamma not implemented for StrongConvex Loss"
"function: %s" % str(self.__class__.__name__))
def beta(self, class_weight):
"""Smoothness, beta.
Args:
class_weight: the class weights as scalar or 1d tensor, where its
dimensionality is equal to the number of outputs.
Returns:
Beta
"""
raise NotImplementedError("Beta not implemented for StrongConvex Loss"
"function: %s" % str(self.__class__.__name__))
def lipchitz_constant(self, class_weight):
"""Lipchitz constant, L.
Args:
class_weight: class weights used
Returns: L
"""
raise NotImplementedError("lipchitz constant not implemented for "
"StrongConvex Loss"
"function: %s" % str(self.__class__.__name__))
def kernel_regularizer(self):
"""Returns the kernel_regularizer to be used.
Any subclass should override this method if they want a kernel_regularizer
(if required for the loss function to be StronglyConvex.
"""
return None
def max_class_weight(self, class_weight, dtype):
"""The maximum weighting in class weights (max value) as a scalar tensor.
Args:
class_weight: class weights used
dtype: the data type for tensor conversions.
Returns:
maximum class weighting as tensor scalar
"""
class_weight = _ops.convert_to_tensor_v2(class_weight, dtype)
return tf.math.reduce_max(class_weight)
class StrongConvexHuber(losses.Loss, StrongConvexMixin):
"""Strong Convex version of Huber loss using l2 weight regularization."""
def __init__(self,
reg_lambda,
c_arg,
radius_constant,
delta,
reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
dtype=tf.float32):
"""Constructor.
Args:
reg_lambda: Weight regularization constant
c_arg: Penalty parameter C of the loss term
radius_constant: constant defining the length of the radius
delta: delta value in huber loss. When to switch from quadratic to
absolute deviation.
reduction: reduction type to use. See super class
dtype: tf datatype to use for tensor conversions.
Returns:
Loss values per sample.
"""
if c_arg <= 0:
raise ValueError("c: {0}, should be >= 0".format(c_arg))
if reg_lambda <= 0:
raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
if radius_constant <= 0:
raise ValueError("radius_constant: {0}, should be >= 0".format(
radius_constant
))
if delta <= 0:
raise ValueError("delta: {0}, should be >= 0".format(
delta
))
self.C = c_arg # pylint: disable=invalid-name
self.delta = delta
self.radius_constant = radius_constant
self.dtype = dtype
self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
super(StrongConvexHuber, self).__init__(
name="strongconvexhuber",
reduction=reduction,
)
def call(self, y_true, y_pred):
"""Computes loss.
Args:
y_true: Ground truth values. One hot encoded using -1 and 1.
y_pred: The predicted values.
Returns:
Loss values per sample.
"""
h = self.delta
z = y_pred * y_true
one = tf.constant(1, dtype=self.dtype)
four = tf.constant(4, dtype=self.dtype)
if z > one + h: # pylint: disable=no-else-return
return _ops.convert_to_tensor_v2(0, dtype=self.dtype)
elif tf.math.abs(one - z) <= h:
return one / (four * h) * tf.math.pow(one + h - z, 2)
return one - z
def radius(self):
"""See super class."""
return self.radius_constant / self.reg_lambda
def gamma(self):
"""See super class."""
return self.reg_lambda
def beta(self, class_weight):
"""See super class."""
max_class_weight = self.max_class_weight(class_weight, self.dtype)
delta = _ops.convert_to_tensor_v2(self.delta,
dtype=self.dtype
)
return self.C * max_class_weight / (delta *
tf.constant(2, dtype=self.dtype)) + \
self.reg_lambda
def lipchitz_constant(self, class_weight):
"""See super class."""
# if class_weight is provided,
# it should be a vector of the same size of number of classes
max_class_weight = self.max_class_weight(class_weight, self.dtype)
lc = self.C * max_class_weight + \
self.reg_lambda * self.radius()
return lc
def kernel_regularizer(self):
"""Return l2 loss using 0.5*reg_lambda as the l2 term (as desired).
L2 regularization is required for this loss function to be strongly convex.
Returns:
The L2 regularizer layer for this loss function, with regularizer constant
set to half the 0.5 * reg_lambda.
"""
return L1L2(l2=self.reg_lambda/2)
class StrongConvexBinaryCrossentropy(
losses.BinaryCrossentropy,
StrongConvexMixin
):
"""Strongly Convex BinaryCrossentropy loss using l2 weight regularization."""
def __init__(self,
reg_lambda,
c_arg,
radius_constant,
from_logits=True,
label_smoothing=0,
reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
dtype=tf.float32):
"""StrongConvexBinaryCrossentropy class.
Args:
reg_lambda: Weight regularization constant
c_arg: Penalty parameter C of the loss term
radius_constant: constant defining the length of the radius
from_logits: True if the input are unscaled logits. False if they are
already scaled.
label_smoothing: amount of smoothing to perform on labels
relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x). Note, the
impact of this parameter's effect on privacy is not known and thus the
default should be used.
reduction: reduction type to use. See super class
dtype: tf datatype to use for tensor conversions.
"""
if label_smoothing != 0:
logging.warning("The impact of label smoothing on privacy is unknown. "
"Use label smoothing at your own risk as it may not "
"guarantee privacy.")
if reg_lambda <= 0:
raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
if c_arg <= 0:
raise ValueError("c: {0}, should be >= 0".format(c_arg))
if radius_constant <= 0:
raise ValueError("radius_constant: {0}, should be >= 0".format(
radius_constant
))
self.dtype = dtype
self.C = c_arg # pylint: disable=invalid-name
self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
super(StrongConvexBinaryCrossentropy, self).__init__(
reduction=reduction,
name="strongconvexbinarycrossentropy",
from_logits=from_logits,
label_smoothing=label_smoothing,
)
self.radius_constant = radius_constant
def call(self, y_true, y_pred):
"""Computes loss.
Args:
y_true: Ground truth values.
y_pred: The predicted values.
Returns:
Loss values per sample.
"""
loss = super(StrongConvexBinaryCrossentropy, self).call(y_true, y_pred)
loss = loss * self.C
return loss
def radius(self):
"""See super class."""
return self.radius_constant / self.reg_lambda
def gamma(self):
"""See super class."""
return self.reg_lambda
def beta(self, class_weight):
"""See super class."""
max_class_weight = self.max_class_weight(class_weight, self.dtype)
return self.C * max_class_weight + self.reg_lambda
def lipchitz_constant(self, class_weight):
"""See super class."""
max_class_weight = self.max_class_weight(class_weight, self.dtype)
return self.C * max_class_weight + self.reg_lambda * self.radius()
def kernel_regularizer(self):
"""Return l2 loss using 0.5*reg_lambda as the l2 term (as desired).
L2 regularization is required for this loss function to be strongly convex.
Returns:
The L2 regularizer layer for this loss function, with regularizer constant
set to half the 0.5 * reg_lambda.
"""
return L1L2(l2=self.reg_lambda/2)

View file

@ -0,0 +1,431 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit testing for losses."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from contextlib import contextmanager # pylint: disable=g-importing-member
from io import StringIO # pylint: disable=g-importing-member
import sys
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.regularizers import L1L2
from privacy.bolt_on.losses import StrongConvexBinaryCrossentropy
from privacy.bolt_on.losses import StrongConvexHuber
from privacy.bolt_on.losses import StrongConvexMixin
@contextmanager
def captured_output():
"""Capture std_out and std_err within context."""
new_out, new_err = StringIO(), StringIO()
old_out, old_err = sys.stdout, sys.stderr
try:
sys.stdout, sys.stderr = new_out, new_err
yield sys.stdout, sys.stderr
finally:
sys.stdout, sys.stderr = old_out, old_err
class StrongConvexMixinTests(keras_parameterized.TestCase):
"""Tests for the StrongConvexMixin."""
@parameterized.named_parameters([
{'testcase_name': 'beta not implemented',
'fn': 'beta',
'args': [1]},
{'testcase_name': 'gamma not implemented',
'fn': 'gamma',
'args': []},
{'testcase_name': 'lipchitz not implemented',
'fn': 'lipchitz_constant',
'args': [1]},
{'testcase_name': 'radius not implemented',
'fn': 'radius',
'args': []},
])
def test_not_implemented(self, fn, args):
"""Test that the given fn's are not implemented on the mixin.
Args:
fn: fn on Mixin to test
args: arguments to fn of Mixin
"""
with self.assertRaises(NotImplementedError):
loss = StrongConvexMixin()
getattr(loss, fn, None)(*args)
@parameterized.named_parameters([
{'testcase_name': 'radius not implemented',
'fn': 'kernel_regularizer',
'args': []},
])
def test_return_none(self, fn, args):
"""Test that fn of Mixin returns None.
Args:
fn: fn of Mixin to test
args: arguments to fn of Mixin
"""
loss = StrongConvexMixin()
ret = getattr(loss, fn, None)(*args)
self.assertEqual(ret, None)
class BinaryCrossesntropyTests(keras_parameterized.TestCase):
"""tests for BinaryCrossesntropy StrongConvex loss."""
@parameterized.named_parameters([
{'testcase_name': 'normal',
'reg_lambda': 1,
'C': 1,
'radius_constant': 1
}, # pylint: disable=invalid-name
])
def test_init_params(self, reg_lambda, C, radius_constant):
"""Test initialization for given arguments.
Args:
reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg
radius_constant: initialization value for radius_constant arg
"""
# test valid domains for each variable
loss = StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
self.assertIsInstance(loss, StrongConvexBinaryCrossentropy)
@parameterized.named_parameters([
{'testcase_name': 'negative c',
'reg_lambda': 1,
'C': -1,
'radius_constant': 1
},
{'testcase_name': 'negative radius',
'reg_lambda': 1,
'C': 1,
'radius_constant': -1
},
{'testcase_name': 'negative lambda',
'reg_lambda': -1,
'C': 1,
'radius_constant': 1
}, # pylint: disable=invalid-name
])
def test_bad_init_params(self, reg_lambda, C, radius_constant):
"""Test invalid domain for given params. Should return ValueError.
Args:
reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg
radius_constant: initialization value for radius_constant arg
"""
# test valid domains for each variable
with self.assertRaises(ValueError):
StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
# [] for compatibility with tensorflow loss calculation
{'testcase_name': 'both positive',
'logits': [10000],
'y_true': [1],
'result': 0,
},
{'testcase_name': 'positive gradient negative logits',
'logits': [-10000],
'y_true': [1],
'result': 10000,
},
{'testcase_name': 'positivee gradient positive logits',
'logits': [10000],
'y_true': [0],
'result': 10000,
},
{'testcase_name': 'both negative',
'logits': [-10000],
'y_true': [0],
'result': 0
},
])
def test_calculation(self, logits, y_true, result):
"""Test the call method to ensure it returns the correct value.
Args:
logits: unscaled output of model
y_true: label
result: correct loss calculation value
"""
logits = tf.Variable(logits, False, dtype=tf.float32)
y_true = tf.Variable(y_true, False, dtype=tf.float32)
loss = StrongConvexBinaryCrossentropy(0.00001, 1, 1)
loss = loss(y_true, logits)
self.assertEqual(loss.numpy(), result)
@parameterized.named_parameters([
{'testcase_name': 'beta',
'init_args': [1, 1, 1],
'fn': 'beta',
'args': [1],
'result': tf.constant(2, dtype=tf.float32)
},
{'testcase_name': 'gamma',
'fn': 'gamma',
'init_args': [1, 1, 1],
'args': [],
'result': tf.constant(1, dtype=tf.float32),
},
{'testcase_name': 'lipchitz constant',
'fn': 'lipchitz_constant',
'init_args': [1, 1, 1],
'args': [1],
'result': tf.constant(2, dtype=tf.float32),
},
{'testcase_name': 'kernel regularizer',
'fn': 'kernel_regularizer',
'init_args': [1, 1, 1],
'args': [],
'result': L1L2(l2=0.5),
},
])
def test_fns(self, init_args, fn, args, result):
"""Test that fn of BinaryCrossentropy loss returns the correct result.
Args:
init_args: init values for loss instance
fn: the fn to test
args: the arguments to above function
result: the correct result from the fn
"""
loss = StrongConvexBinaryCrossentropy(*init_args)
expected = getattr(loss, fn, lambda: 'fn not found')(*args)
if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor
expected = expected.numpy()
result = result.numpy()
if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer
expected = expected.l2
result = result.l2
self.assertEqual(expected, result)
@parameterized.named_parameters([
{'testcase_name': 'label_smoothing',
'init_args': [1, 1, 1, True, 0.1],
'fn': None,
'args': None,
'print_res': 'The impact of label smoothing on privacy is unknown.'
},
])
def test_prints(self, init_args, fn, args, print_res):
"""Test logger warning from StrongConvexBinaryCrossentropy.
Args:
init_args: arguments to init the object with.
fn: function to test
args: arguments to above function
print_res: print result that should have been printed.
"""
with captured_output() as (out, err): # pylint: disable=unused-variable
loss = StrongConvexBinaryCrossentropy(*init_args)
if fn is not None:
getattr(loss, fn, lambda *arguments: print('error'))(*args)
self.assertRegexMatch(err.getvalue().strip(), [print_res])
class HuberTests(keras_parameterized.TestCase):
"""tests for BinaryCrossesntropy StrongConvex loss."""
@parameterized.named_parameters([
{'testcase_name': 'normal',
'reg_lambda': 1,
'c': 1,
'radius_constant': 1,
'delta': 1,
},
])
def test_init_params(self, reg_lambda, c, radius_constant, delta):
"""Test initialization for given arguments.
Args:
reg_lambda: initialization value for reg_lambda arg
c: initialization value for C arg
radius_constant: initialization value for radius_constant arg
delta: the delta parameter for the huber loss
"""
# test valid domains for each variable
loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta)
self.assertIsInstance(loss, StrongConvexHuber)
@parameterized.named_parameters([
{'testcase_name': 'negative c',
'reg_lambda': 1,
'c': -1,
'radius_constant': 1,
'delta': 1
},
{'testcase_name': 'negative radius',
'reg_lambda': 1,
'c': 1,
'radius_constant': -1,
'delta': 1
},
{'testcase_name': 'negative lambda',
'reg_lambda': -1,
'c': 1,
'radius_constant': 1,
'delta': 1
},
{'testcase_name': 'negative delta',
'reg_lambda': 1,
'c': 1,
'radius_constant': 1,
'delta': -1
},
])
def test_bad_init_params(self, reg_lambda, c, radius_constant, delta):
"""Test invalid domain for given params. Should return ValueError.
Args:
reg_lambda: initialization value for reg_lambda arg
c: initialization value for C arg
radius_constant: initialization value for radius_constant arg
delta: the delta parameter for the huber loss
"""
# test valid domains for each variable
with self.assertRaises(ValueError):
StrongConvexHuber(reg_lambda, c, radius_constant, delta)
# test the bounds and test varied delta's
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': 'delta=1,y_true=1 z>1+h decision boundary',
'logits': 2.1,
'y_true': 1,
'delta': 1,
'result': 0,
},
{'testcase_name': 'delta=1,y_true=1 z<1+h decision boundary',
'logits': 1.9,
'y_true': 1,
'delta': 1,
'result': 0.01*0.25,
},
{'testcase_name': 'delta=1,y_true=1 1-z< h decision boundary',
'logits': 0.1,
'y_true': 1,
'delta': 1,
'result': 1.9**2 * 0.25,
},
{'testcase_name': 'delta=1,y_true=1 z < 1-h decision boundary',
'logits': -0.1,
'y_true': 1,
'delta': 1,
'result': 1.1,
},
{'testcase_name': 'delta=2,y_true=1 z>1+h decision boundary',
'logits': 3.1,
'y_true': 1,
'delta': 2,
'result': 0,
},
{'testcase_name': 'delta=2,y_true=1 z<1+h decision boundary',
'logits': 2.9,
'y_true': 1,
'delta': 2,
'result': 0.01*0.125,
},
{'testcase_name': 'delta=2,y_true=1 1-z < h decision boundary',
'logits': 1.1,
'y_true': 1,
'delta': 2,
'result': 1.9**2 * 0.125,
},
{'testcase_name': 'delta=2,y_true=1 z < 1-h decision boundary',
'logits': -1.1,
'y_true': 1,
'delta': 2,
'result': 2.1,
},
{'testcase_name': 'delta=1,y_true=-1 z>1+h decision boundary',
'logits': -2.1,
'y_true': -1,
'delta': 1,
'result': 0,
},
])
def test_calculation(self, logits, y_true, delta, result):
"""Test the call method to ensure it returns the correct value.
Args:
logits: unscaled output of model
y_true: label
delta: delta value for StrongConvexHuber loss.
result: correct loss calculation value
"""
logits = tf.Variable(logits, False, dtype=tf.float32)
y_true = tf.Variable(y_true, False, dtype=tf.float32)
loss = StrongConvexHuber(0.00001, 1, 1, delta)
loss = loss(y_true, logits)
self.assertAllClose(loss.numpy(), result)
@parameterized.named_parameters([
{'testcase_name': 'beta',
'init_args': [1, 1, 1, 1],
'fn': 'beta',
'args': [1],
'result': tf.Variable(1.5, dtype=tf.float32)
},
{'testcase_name': 'gamma',
'fn': 'gamma',
'init_args': [1, 1, 1, 1],
'args': [],
'result': tf.Variable(1, dtype=tf.float32),
},
{'testcase_name': 'lipchitz constant',
'fn': 'lipchitz_constant',
'init_args': [1, 1, 1, 1],
'args': [1],
'result': tf.Variable(2, dtype=tf.float32),
},
{'testcase_name': 'kernel regularizer',
'fn': 'kernel_regularizer',
'init_args': [1, 1, 1, 1],
'args': [],
'result': L1L2(l2=0.5),
},
])
def test_fns(self, init_args, fn, args, result):
"""Test that fn of BinaryCrossentropy loss returns the correct result.
Args:
init_args: init values for loss instance
fn: the fn to test
args: the arguments to above function
result: the correct result from the fn
"""
loss = StrongConvexHuber(*init_args)
expected = getattr(loss, fn, lambda: 'fn not found')(*args)
if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor
expected = expected.numpy()
result = result.numpy()
if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer
expected = expected.l2
result = result.l2
self.assertEqual(expected, result)
if __name__ == '__main__':
tf.test.main()

297
privacy/bolt_on/models.py Normal file
View file

@ -0,0 +1,297 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BoltOn model for Bolt-on method of differentially private ML."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.framework import ops as _ops
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.models import Model
from privacy.bolt_on.losses import StrongConvexMixin
from privacy.bolt_on.optimizers import BoltOn
class BoltOnModel(Model): # pylint: disable=abstract-method
"""BoltOn episilon-delta differential privacy model.
The privacy guarantees are dependent on the noise that is sampled. Please
see the paper linked below for more details.
Uses 4 key steps to achieve privacy guarantees:
1. Adds noise to weights after training (output perturbation).
2. Projects weights to R after each batch
3. Limits learning rate
4. Use a strongly convex loss function (see compile)
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et al.
"""
def __init__(self,
n_outputs,
seed=1,
dtype=tf.float32):
"""Private constructor.
Args:
n_outputs: number of output classes to predict.
seed: random seed to use
dtype: data type to use for tensors
"""
super(BoltOnModel, self).__init__(name='bolton', dynamic=False)
if n_outputs <= 0:
raise ValueError('n_outputs = {0} is not valid. Must be > 0.'.format(
n_outputs
))
self.n_outputs = n_outputs
self.seed = seed
self._layers_instantiated = False
self._dtype = dtype
def call(self, inputs): # pylint: disable=arguments-differ
"""Forward pass of network.
Args:
inputs: inputs to neural network
Returns:
Output logits for the given inputs.
"""
return self.output_layer(inputs)
def compile(self,
optimizer,
loss,
kernel_initializer=tf.initializers.GlorotUniform,
**kwargs): # pylint: disable=arguments-differ
"""See super class. Default optimizer used in BoltOn method is SGD.
Args:
optimizer: The optimizer to use. This will be automatically wrapped
with the BoltOn Optimizer.
loss: The loss function to use. Must be a StrongConvex loss (extend the
StrongConvexMixin).
kernel_initializer: The kernel initializer to use for the single layer.
**kwargs: kwargs to keras Model.compile. See super.
"""
if not isinstance(loss, StrongConvexMixin):
raise ValueError('loss function must be a Strongly Convex and therefore '
'extend the StrongConvexMixin.')
if not self._layers_instantiated: # compile may be called multiple times
# for instance, if the input/outputs are not defined until fit.
self.output_layer = tf.keras.layers.Dense(
self.n_outputs,
kernel_regularizer=loss.kernel_regularizer(),
kernel_initializer=kernel_initializer(),
)
self._layers_instantiated = True
if not isinstance(optimizer, BoltOn):
optimizer = optimizers.get(optimizer)
optimizer = BoltOn(optimizer, loss)
super(BoltOnModel, self).compile(optimizer, loss=loss, **kwargs)
def fit(self,
x=None,
y=None,
batch_size=None,
class_weight=None,
n_samples=None,
epsilon=2,
noise_distribution='laplace',
steps_per_epoch=None,
**kwargs): # pylint: disable=arguments-differ
"""Reroutes to super fit with BoltOn delta-epsilon privacy requirements.
Note, inputs must be normalized s.t. ||x|| < 1.
Requirements are as follows:
1. Adds noise to weights after training (output perturbation).
2. Projects weights to R after each batch
3. Limits learning rate
4. Use a strongly convex loss function (see compile)
See super implementation for more details.
Args:
x: Inputs to fit on, see super.
y: Labels to fit on, see super.
batch_size: The batch size to use for training, see super.
class_weight: the class weights to be used. Can be a scalar or 1D tensor
whose dim == n_classes.
n_samples: the number of individual samples in x.
epsilon: privacy parameter, which trades off between utility an privacy.
See the bolt-on paper for more description.
noise_distribution: the distribution to pull noise from.
steps_per_epoch:
**kwargs: kwargs to keras Model.fit. See super.
Returns:
Output from super fit method.
"""
if class_weight is None:
class_weight_ = self.calculate_class_weights(class_weight)
else:
class_weight_ = class_weight
if n_samples is not None:
data_size = n_samples
elif hasattr(x, 'shape'):
data_size = x.shape[0]
elif hasattr(x, '__len__'):
data_size = len(x)
else:
data_size = None
batch_size_ = self._validate_or_infer_batch_size(batch_size,
steps_per_epoch,
x)
# inferring batch_size to be passed to optimizer. batch_size must remain its
# initial value when passed to super().fit()
if batch_size_ is None:
raise ValueError('batch_size: {0} is an '
'invalid value'.format(batch_size_))
if data_size is None:
raise ValueError('Could not infer the number of samples. Please pass '
'this in using n_samples.')
with self.optimizer(noise_distribution,
epsilon,
self.layers,
class_weight_,
data_size,
batch_size_) as _:
out = super(BoltOnModel, self).fit(x=x,
y=y,
batch_size=batch_size,
class_weight=class_weight,
steps_per_epoch=steps_per_epoch,
**kwargs)
return out
def fit_generator(self,
generator,
class_weight=None,
noise_distribution='laplace',
epsilon=2,
n_samples=None,
steps_per_epoch=None,
**kwargs): # pylint: disable=arguments-differ
"""Fit with a generator.
This method is the same as fit except for when the passed dataset
is a generator. See super method and fit for more details.
Args:
generator: Inputs generator following Tensorflow guidelines, see super.
class_weight: the class weights to be used. Can be a scalar or 1D tensor
whose dim == n_classes.
noise_distribution: the distribution to get noise from.
epsilon: privacy parameter, which trades off utility and privacy. See
BoltOn paper for more description.
n_samples: number of individual samples in x
steps_per_epoch: Number of steps per training epoch, see super.
**kwargs: **kwargs
Returns:
Output from super fit_generator method.
"""
if class_weight is None:
class_weight = self.calculate_class_weights(class_weight)
if n_samples is not None:
data_size = n_samples
elif hasattr(generator, 'shape'):
data_size = generator.shape[0]
elif hasattr(generator, '__len__'):
data_size = len(generator)
else:
data_size = None
batch_size = self._validate_or_infer_batch_size(None,
steps_per_epoch,
generator)
with self.optimizer(noise_distribution,
epsilon,
self.layers,
class_weight,
data_size,
batch_size) as _:
out = super(BoltOnModel, self).fit_generator(
generator,
class_weight=class_weight,
steps_per_epoch=steps_per_epoch,
**kwargs)
return out
def calculate_class_weights(self,
class_weights=None,
class_counts=None,
num_classes=None):
"""Calculates class weighting to be used in training.
Args:
class_weights: str specifying type, array giving weights, or None.
class_counts: If class_weights is not None, then an array of
the number of samples for each class
num_classes: If class_weights is not None, then the number of
classes.
Returns:
class_weights as 1D tensor, to be passed to model's fit method.
"""
# Value checking
class_keys = ['balanced']
is_string = False
if isinstance(class_weights, str):
is_string = True
if class_weights not in class_keys:
raise ValueError('Detected string class_weights with '
'value: {0}, which is not one of {1}.'
'Please select a valid class_weight type'
'or pass an array'.format(class_weights,
class_keys))
if class_counts is None:
raise ValueError('Class counts must be provided if using '
'class_weights=%s' % class_weights)
class_counts_shape = tf.Variable(class_counts,
trainable=False,
dtype=self._dtype).shape
if len(class_counts_shape) != 1:
raise ValueError('class counts must be a 1D array.'
'Detected: {0}'.format(class_counts_shape))
if num_classes is None:
raise ValueError('num_classes must be provided if using '
'class_weights=%s' % class_weights)
elif class_weights is not None:
if num_classes is None:
raise ValueError('You must pass a value for num_classes if '
'creating an array of class_weights')
# performing class weight calculation
if class_weights is None:
class_weights = 1
elif is_string and class_weights == 'balanced':
num_samples = sum(class_counts)
weighted_counts = tf.dtypes.cast(tf.math.multiply(num_classes,
class_counts),
self._dtype)
class_weights = tf.Variable(num_samples, dtype=self._dtype) / \
tf.Variable(weighted_counts, dtype=self._dtype)
else:
class_weights = _ops.convert_to_tensor_v2(class_weights)
if len(class_weights.shape) != 1:
raise ValueError('Detected class_weights shape: {0} instead of '
'1D array'.format(class_weights.shape))
if class_weights.shape[0] != num_classes:
raise ValueError(
'Detected array length: {0} instead of: {1}'.format(
class_weights.shape[0],
num_classes))
return class_weights

View file

@ -0,0 +1,534 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit testing for models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.framework import ops as _ops
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import losses
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
from tensorflow.python.keras.regularizers import L1L2
from privacy.bolt_on import models
from privacy.bolt_on.losses import StrongConvexMixin
from privacy.bolt_on.optimizers import BoltOn
class TestLoss(losses.Loss, StrongConvexMixin):
"""Test loss function for testing BoltOn model."""
def __init__(self, reg_lambda, c_arg, radius_constant, name='test'):
super(TestLoss, self).__init__(name=name)
self.reg_lambda = reg_lambda
self.C = c_arg # pylint: disable=invalid-name
self.radius_constant = radius_constant
def radius(self):
"""Radius, R, of the hypothesis space W.
W is a convex set that forms the hypothesis space.
Returns:
radius
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def gamma(self):
"""Returns strongly convex parameter, gamma."""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def beta(self, class_weight): # pylint: disable=unused-argument
"""Smoothness, beta.
Args:
class_weight: the class weights as scalar or 1d tensor, where its
dimensionality is equal to the number of outputs.
Returns:
Beta
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def lipchitz_constant(self, class_weight): # pylint: disable=unused-argument
"""Lipchitz constant, L.
Args:
class_weight: class weights used
Returns:
L
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def call(self, y_true, y_pred):
"""Loss function that is minimized at the mean of the input points."""
return 0.5 * tf.reduce_sum(
tf.math.squared_difference(y_true, y_pred),
axis=1
)
def max_class_weight(self, class_weight):
"""the maximum weighting in class weights (max value) as a scalar tensor.
Args:
class_weight: class weights used
Returns:
maximum class weighting as tensor scalar
"""
if class_weight is None:
return 1
raise ValueError('')
def kernel_regularizer(self):
"""Returns the kernel_regularizer to be used.
Any subclass should override this method if they want a kernel_regularizer
(if required for the loss function to be StronglyConvex.
"""
return L1L2(l2=self.reg_lambda)
class TestOptimizer(OptimizerV2):
"""Test optimizer used for testing BoltOn model."""
def __init__(self):
super(TestOptimizer, self).__init__('test')
def compute_gradients(self):
return 0
def get_config(self):
return {}
def _create_slots(self, var):
pass
def _resource_apply_dense(self, grad, handle):
return grad
def _resource_apply_sparse(self, grad, handle, indices):
return grad
class InitTests(keras_parameterized.TestCase):
"""Tests for keras model initialization."""
@parameterized.named_parameters([
{'testcase_name': 'normal',
'n_outputs': 1,
},
{'testcase_name': 'many outputs',
'n_outputs': 100,
},
])
def test_init_params(self, n_outputs):
"""Test initialization of BoltOnModel.
Args:
n_outputs: number of output neurons
"""
# test valid domains for each variable
clf = models.BoltOnModel(n_outputs)
self.assertIsInstance(clf, models.BoltOnModel)
@parameterized.named_parameters([
{'testcase_name': 'invalid n_outputs',
'n_outputs': -1,
},
])
def test_bad_init_params(self, n_outputs):
"""test bad initializations of BoltOnModel that should raise errors.
Args:
n_outputs: number of output neurons
"""
# test invalid domains for each variable, especially noise
with self.assertRaises(ValueError):
models.BoltOnModel(n_outputs)
@parameterized.named_parameters([
{'testcase_name': 'string compile',
'n_outputs': 1,
'loss': TestLoss(1, 1, 1),
'optimizer': 'adam',
},
{'testcase_name': 'test compile',
'n_outputs': 100,
'loss': TestLoss(1, 1, 1),
'optimizer': TestOptimizer(),
},
])
def test_compile(self, n_outputs, loss, optimizer):
"""Test compilation of BoltOnModel.
Args:
n_outputs: number of output neurons
loss: instantiated TestLoss instance
optimizer: instantiated TestOptimizer instance
"""
# test compilation of valid tf.optimizer and tf.loss
with self.cached_session():
clf = models.BoltOnModel(n_outputs)
clf.compile(optimizer, loss)
self.assertEqual(clf.loss, loss)
@parameterized.named_parameters([
{'testcase_name': 'Not strong loss',
'n_outputs': 1,
'loss': losses.BinaryCrossentropy(),
'optimizer': 'adam',
},
{'testcase_name': 'Not valid optimizer',
'n_outputs': 1,
'loss': TestLoss(1, 1, 1),
'optimizer': 'ada',
}
])
def test_bad_compile(self, n_outputs, loss, optimizer):
"""test bad compilations of BoltOnModel that should raise errors.
Args:
n_outputs: number of output neurons
loss: instantiated TestLoss instance
optimizer: instantiated TestOptimizer instance
"""
# test compilaton of invalid tf.optimizer and non instantiated loss.
with self.cached_session():
with self.assertRaises((ValueError, AttributeError)):
clf = models.BoltOnModel(n_outputs)
clf.compile(optimizer, loss)
def _cat_dataset(n_samples, input_dim, n_classes, generator=False):
"""Creates a categorically encoded dataset.
Creates a categorically encoded dataset (y is categorical).
returns the specified dataset either as a static array or as a generator.
Will have evenly split samples across each output class.
Each output class will be a different point in the input space.
Args:
n_samples: number of rows
input_dim: input dimensionality
n_classes: output dimensionality
generator: False for array, True for generator
Returns:
X as (n_samples, input_dim), Y as (n_samples, n_outputs)
"""
x_stack = []
y_stack = []
for i_class in range(n_classes):
x_stack.append(
tf.constant(1*i_class, tf.float32, (n_samples, input_dim))
)
y_stack.append(
tf.constant(i_class, tf.float32, (n_samples, n_classes))
)
x_set, y_set = tf.stack(x_stack), tf.stack(y_stack)
if generator:
dataset = tf.data.Dataset.from_tensor_slices(
(x_set, y_set)
)
return dataset
return x_set, y_set
def _do_fit(n_samples,
input_dim,
n_outputs,
epsilon,
generator,
batch_size,
reset_n_samples,
optimizer,
loss,
distribution='laplace'):
"""Instantiate necessary components for fitting and perform a model fit.
Args:
n_samples: number of samples in dataset
input_dim: the sample dimensionality
n_outputs: number of output neurons
epsilon: privacy parameter
generator: True to create a generator, False to use an iterator
batch_size: batch_size to use
reset_n_samples: True to set _samples to None prior to fitting.
False does nothing
optimizer: instance of TestOptimizer
loss: instance of TestLoss
distribution: distribution to get noise from.
Returns:
BoltOnModel instsance
"""
clf = models.BoltOnModel(n_outputs)
clf.compile(optimizer, loss)
if generator:
x = _cat_dataset(
n_samples,
input_dim,
n_outputs,
generator=generator
)
y = None
# x = x.batch(batch_size)
x = x.shuffle(n_samples//2)
batch_size = None
else:
x, y = _cat_dataset(n_samples, input_dim, n_outputs, generator=generator)
if reset_n_samples:
n_samples = None
clf.fit(x,
y,
batch_size=batch_size,
n_samples=n_samples,
noise_distribution=distribution,
epsilon=epsilon)
return clf
class FitTests(keras_parameterized.TestCase):
"""Test cases for keras model fitting."""
# @test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': 'iterator fit',
'generator': False,
'reset_n_samples': True,
},
{'testcase_name': 'iterator fit no samples',
'generator': False,
'reset_n_samples': True,
},
{'testcase_name': 'generator fit',
'generator': True,
'reset_n_samples': False,
},
{'testcase_name': 'with callbacks',
'generator': True,
'reset_n_samples': False,
},
])
def test_fit(self, generator, reset_n_samples):
"""Tests fitting of BoltOnModel.
Args:
generator: True for generator test, False for iterator test.
reset_n_samples: True to reset the n_samples to None, False does nothing
"""
loss = TestLoss(1, 1, 1)
optimizer = BoltOn(TestOptimizer(), loss)
n_classes = 2
input_dim = 5
epsilon = 1
batch_size = 1
n_samples = 10
clf = _do_fit(
n_samples,
input_dim,
n_classes,
epsilon,
generator,
batch_size,
reset_n_samples,
optimizer,
loss,
)
self.assertEqual(hasattr(clf, 'layers'), True)
@parameterized.named_parameters([
{'testcase_name': 'generator fit',
'generator': True,
},
])
def test_fit_gen(self, generator):
"""Tests the fit_generator method of BoltOnModel.
Args:
generator: True to test with a generator dataset
"""
loss = TestLoss(1, 1, 1)
optimizer = TestOptimizer()
n_classes = 2
input_dim = 5
batch_size = 1
n_samples = 10
clf = models.BoltOnModel(n_classes)
clf.compile(optimizer, loss)
x = _cat_dataset(
n_samples,
input_dim,
n_classes,
generator=generator
)
x = x.batch(batch_size)
x = x.shuffle(n_samples // 2)
clf.fit_generator(x, n_samples=n_samples)
self.assertEqual(hasattr(clf, 'layers'), True)
@parameterized.named_parameters([
{'testcase_name': 'iterator no n_samples',
'generator': True,
'reset_n_samples': True,
'distribution': 'laplace'
},
{'testcase_name': 'invalid distribution',
'generator': True,
'reset_n_samples': True,
'distribution': 'not_valid'
},
])
def test_bad_fit(self, generator, reset_n_samples, distribution):
"""Tests fitting with invalid parameters, which should raise an error.
Args:
generator: True to test with generator, False is iterator
reset_n_samples: True to reset the n_samples param to None prior to
passing it to fit
distribution: distribution to get noise from.
"""
with self.assertRaises(ValueError):
loss = TestLoss(1, 1, 1)
optimizer = TestOptimizer()
n_classes = 2
input_dim = 5
epsilon = 1
batch_size = 1
n_samples = 10
_do_fit(
n_samples,
input_dim,
n_classes,
epsilon,
generator,
batch_size,
reset_n_samples,
optimizer,
loss,
distribution
)
@parameterized.named_parameters([
{'testcase_name': 'None class_weights',
'class_weights': None,
'class_counts': None,
'num_classes': None,
'result': 1},
{'testcase_name': 'class weights array',
'class_weights': [1, 1],
'class_counts': [1, 1],
'num_classes': 2,
'result': [1, 1]},
{'testcase_name': 'class weights balanced',
'class_weights': 'balanced',
'class_counts': [1, 1],
'num_classes': 2,
'result': [1, 1]},
])
def test_class_calculate(self,
class_weights,
class_counts,
num_classes,
result):
"""Tests the BOltonModel calculate_class_weights method.
Args:
class_weights: the class_weights to use
class_counts: count of number of samples for each class
num_classes: number of outputs neurons
result: expected result
"""
clf = models.BoltOnModel(1, 1)
expected = clf.calculate_class_weights(class_weights,
class_counts,
num_classes)
if hasattr(expected, 'numpy'):
expected = expected.numpy()
self.assertAllEqual(
expected,
result
)
@parameterized.named_parameters([
{'testcase_name': 'class weight not valid str',
'class_weights': 'not_valid',
'class_counts': 1,
'num_classes': 1,
'err_msg': 'Detected string class_weights with value: not_valid'},
{'testcase_name': 'no class counts',
'class_weights': 'balanced',
'class_counts': None,
'num_classes': 1,
'err_msg': 'Class counts must be provided if '
'using class_weights=balanced'},
{'testcase_name': 'no num classes',
'class_weights': 'balanced',
'class_counts': [1],
'num_classes': None,
'err_msg': 'num_classes must be provided if '
'using class_weights=balanced'},
{'testcase_name': 'class counts not array',
'class_weights': 'balanced',
'class_counts': 1,
'num_classes': None,
'err_msg': 'class counts must be a 1D array.'},
{'testcase_name': 'class counts array, no num classes',
'class_weights': [1],
'class_counts': None,
'num_classes': None,
'err_msg': 'You must pass a value for num_classes if '
'creating an array of class_weights'},
{'testcase_name': 'class counts array, improper shape',
'class_weights': [[1], [1]],
'class_counts': None,
'num_classes': 2,
'err_msg': 'Detected class_weights shape'},
{'testcase_name': 'class counts array, wrong number classes',
'class_weights': [1, 1, 1],
'class_counts': None,
'num_classes': 2,
'err_msg': 'Detected array length:'},
])
def test_class_errors(self,
class_weights,
class_counts,
num_classes,
err_msg):
"""Tests the BOltonModel calculate_class_weights method.
This test passes invalid params which should raise the expected errors.
Args:
class_weights: the class_weights to use.
class_counts: count of number of samples for each class.
num_classes: number of outputs neurons.
err_msg: The expected error message.
"""
clf = models.BoltOnModel(1, 1)
with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method
clf.calculate_class_weights(class_weights,
class_counts,
num_classes)
if __name__ == '__main__':
tf.test.main()

View file

@ -0,0 +1,390 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BoltOn Optimizer for Bolt-on method."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import math_ops
from privacy.bolt_on.losses import StrongConvexMixin
_accepted_distributions = ['laplace'] # implemented distributions for noising
class GammaBetaDecreasingStep(
optimizer_v2.learning_rate_schedule.LearningRateSchedule):
"""Computes LR as minimum of 1/beta and 1/(gamma * step) at each step.
This is a required step for privacy guarantees.
"""
def __init__(self):
self.is_init = False
self.beta = None
self.gamma = None
def __call__(self, step):
"""Computes and returns the learning rate.
Args:
step: the current iteration number
Returns:
decayed learning rate to minimum of 1/beta and 1/(gamma * step) as per
the BoltOn privacy requirements.
"""
if not self.is_init:
raise AttributeError('Please initialize the {0} Learning Rate Scheduler.'
'This is performed automatically by using the '
'{1} as a context manager, '
'as desired'.format(self.__class__.__name__,
BoltOn.__class__.__name__
)
)
dtype = self.beta.dtype
one = tf.constant(1, dtype)
return tf.math.minimum(tf.math.reduce_min(one/self.beta),
one/(self.gamma*math_ops.cast(step, dtype))
)
def get_config(self):
"""Return config to setup the learning rate scheduler."""
return {'beta': self.beta, 'gamma': self.gamma}
def initialize(self, beta, gamma):
"""Setups scheduler with beta and gamma values from the loss function.
Meant to be used with .fit as the loss params may depend on values passed to
fit.
Args:
beta: Smoothness value. See StrongConvexMixin
gamma: Strong Convexity parameter. See StrongConvexMixin.
"""
self.is_init = True
self.beta = beta
self.gamma = gamma
def de_initialize(self):
"""De initialize post fit, as another fit call may use other parameters."""
self.is_init = False
self.beta = None
self.gamma = None
class BoltOn(optimizer_v2.OptimizerV2):
"""Wrap another tf optimizer with BoltOn privacy protocol.
BoltOn optimizer wraps another tf optimizer to be used
as the visible optimizer to the tf model. No matter the optimizer
passed, "BoltOn" enables the bolt-on model to control the learning rate
based on the strongly convex loss.
To use the BoltOn method, you must:
1. instantiate it with an instantiated tf optimizer and StrongConvexLoss.
2. use it as a context manager around your .fit method internals.
This can be accomplished by the following:
optimizer = tf.optimizers.SGD()
loss = privacy.bolt_on.losses.StrongConvexBinaryCrossentropy()
bolton = BoltOn(optimizer, loss)
with bolton(*args) as _:
model.fit()
The args required for the context manager can be found in the __call__
method.
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et. al.
"""
def __init__(self, # pylint: disable=super-init-not-called
optimizer,
loss,
dtype=tf.float32,
):
"""Constructor.
Args:
optimizer: Optimizer_v2 or subclass to be used as the optimizer
(wrapped).
loss: StrongConvexLoss function that the model is being compiled with.
dtype: dtype
"""
if not isinstance(loss, StrongConvexMixin):
raise ValueError('loss function must be a Strongly Convex and therefore '
'extend the StrongConvexMixin.')
self._private_attributes = ['_internal_optimizer',
'dtype',
'noise_distribution',
'epsilon',
'loss',
'class_weights',
'input_dim',
'n_samples',
'layers',
'batch_size',
'_is_init'
]
self._internal_optimizer = optimizer
self.learning_rate = GammaBetaDecreasingStep() # use the BoltOn Learning
# rate scheduler, as required for privacy guarantees. This will still need
# to get values from the loss function near the time that .fit is called
# on the model (when this optimizer will be called as a context manager)
self.dtype = dtype
self.loss = loss
self._is_init = False
def get_config(self):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
return self._internal_optimizer.get_config()
def project_weights_to_r(self, force=False):
"""Normalize the weights to the R-ball.
Args:
force: True to normalize regardless of previous weight values.
False to check if weights > R-ball and only normalize then.
Raises:
Exception: If not called from inside this optimizer context.
"""
if not self._is_init:
raise Exception('This method must be called from within the optimizer\'s '
'context.')
radius = self.loss.radius()
for layer in self.layers:
weight_norm = tf.norm(layer.kernel, axis=0)
if force:
layer.kernel = layer.kernel / (weight_norm / radius)
else:
layer.kernel = tf.cond(
tf.reduce_sum(tf.cast(weight_norm > radius, dtype=self.dtype)) > 0,
lambda k=layer.kernel, w=weight_norm, r=radius: k / (w / r), # pylint: disable=cell-var-from-loop
lambda k=layer.kernel: k # pylint: disable=cell-var-from-loop
)
def get_noise(self, input_dim, output_dim):
"""Sample noise to be added to weights for privacy guarantee.
Args:
input_dim: the input dimensionality for the weights
output_dim: the output dimensionality for the weights
Returns:
Noise in shape of layer's weights to be added to the weights.
Raises:
Exception: If not called from inside this optimizer's context.
"""
if not self._is_init:
raise Exception('This method must be called from within the optimizer\'s '
'context.')
loss = self.loss
distribution = self.noise_distribution.lower()
if distribution == _accepted_distributions[0]: # laplace
per_class_epsilon = self.epsilon / (output_dim)
l2_sensitivity = (2 *
loss.lipchitz_constant(self.class_weights)) / \
(loss.gamma() * self.n_samples * self.batch_size)
unit_vector = tf.random.normal(shape=(input_dim, output_dim),
mean=0,
seed=1,
stddev=1.0,
dtype=self.dtype)
unit_vector = unit_vector / tf.math.sqrt(
tf.reduce_sum(tf.math.square(unit_vector), axis=0)
)
beta = l2_sensitivity / per_class_epsilon
alpha = input_dim # input_dim
gamma = tf.random.gamma([output_dim],
alpha,
beta=1 / beta,
seed=1,
dtype=self.dtype
)
return unit_vector * gamma
raise NotImplementedError('Noise distribution: {0} is not '
'a valid distribution'.format(distribution))
def from_config(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
return self._internal_optimizer.from_config(*args, **kwargs)
def __getattr__(self, name):
"""Get attr.
return _internal_optimizer off self instance, and everything else
from the _internal_optimizer instance.
Args:
name: Name of attribute to get from this or aggregate optimizer.
Returns:
attribute from BoltOn if specified to come from self, else
from _internal_optimizer.
"""
if name == '_private_attributes' or name in self._private_attributes:
return getattr(self, name)
optim = object.__getattribute__(self, '_internal_optimizer')
try:
return object.__getattribute__(optim, name)
except AttributeError:
raise AttributeError(
"Neither '{0}' nor '{1}' object has attribute '{2}'"
"".format(self.__class__.__name__,
self._internal_optimizer.__class__.__name__,
name
)
)
def __setattr__(self, key, value):
"""Set attribute to self instance if its the internal optimizer.
Reroute everything else to the _internal_optimizer.
Args:
key: attribute name
value: attribute value
"""
if key == '_private_attributes':
object.__setattr__(self, key, value)
elif key in self._private_attributes:
object.__setattr__(self, key, value)
else:
setattr(self._internal_optimizer, key, value)
def _resource_apply_dense(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
return self._internal_optimizer._resource_apply_dense(*args, **kwargs) # pylint: disable=protected-access
def _resource_apply_sparse(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
return self._internal_optimizer._resource_apply_sparse(*args, **kwargs) # pylint: disable=protected-access
def get_updates(self, loss, params):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
out = self._internal_optimizer.get_updates(loss, params)
self.project_weights_to_r()
return out
def apply_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
out = self._internal_optimizer.apply_gradients(*args, **kwargs)
self.project_weights_to_r()
return out
def minimize(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
out = self._internal_optimizer.minimize(*args, **kwargs)
self.project_weights_to_r()
return out
def _compute_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ,protected-access
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
return self._internal_optimizer._compute_gradients(*args, **kwargs) # pylint: disable=protected-access
def get_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer."""
return self._internal_optimizer.get_gradients(*args, **kwargs)
def __enter__(self):
"""Context manager call at the beginning of with statement.
Returns:
self, to be used in context manager
"""
self._is_init = True
return self
def __call__(self,
noise_distribution,
epsilon,
layers,
class_weights,
n_samples,
batch_size
):
"""Accepts required values for bolton method from context entry point.
Stores them on the optimizer for use throughout fitting.
Args:
noise_distribution: the noise distribution to pick.
see _accepted_distributions and get_noise for possible values.
epsilon: privacy parameter. Lower gives more privacy but less utility.
layers: list of Keras/Tensorflow layers. Can be found as model.layers
class_weights: class_weights used, which may either be a scalar or 1D
tensor with dim == n_classes.
n_samples: number of rows/individual samples in the training set
batch_size: batch size used.
Returns:
self, to be used by the __enter__ method for context.
"""
if epsilon <= 0:
raise ValueError('Detected epsilon: {0}. '
'Valid range is 0 < epsilon <inf'.format(epsilon))
if noise_distribution not in _accepted_distributions:
raise ValueError('Detected noise distribution: {0} not one of: {1} valid'
'distributions'.format(noise_distribution,
_accepted_distributions))
self.noise_distribution = noise_distribution
self.learning_rate.initialize(self.loss.beta(class_weights),
self.loss.gamma()
)
self.epsilon = tf.constant(epsilon, dtype=self.dtype)
self.class_weights = tf.constant(class_weights, dtype=self.dtype)
self.n_samples = tf.constant(n_samples, dtype=self.dtype)
self.layers = layers
self.batch_size = tf.constant(batch_size, dtype=self.dtype)
return self
def __exit__(self, *args):
"""Exit call from with statement.
Used to:
1.reset the model and fit parameters passed to the optimizer
to enable the BoltOn Privacy guarantees. These are reset to ensure
that any future calls to fit with the same instance of the optimizer
will properly error out.
2.call post-fit methods normalizing/projecting the model weights and
adding noise to the weights.
Args:
*args: encompasses the type, value, and traceback values which are unused.
"""
self.project_weights_to_r(True)
for layer in self.layers:
input_dim = layer.kernel.shape[0]
output_dim = layer.units
noise = self.get_noise(input_dim,
output_dim,
)
layer.kernel = tf.math.add(layer.kernel, noise)
self.noise_distribution = None
self.learning_rate.de_initialize()
self.epsilon = -1
self.batch_size = -1
self.class_weights = None
self.n_samples = None
self.input_dim = None
self.layers = None
self._is_init = False

View file

@ -0,0 +1,579 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit testing for optimizers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python import ops as _ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import losses
from tensorflow.python.keras.initializers import constant
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
from tensorflow.python.keras.regularizers import L1L2
from tensorflow.python.platform import test
from privacy.bolt_on import optimizers as opt
from privacy.bolt_on.losses import StrongConvexMixin
class TestModel(Model): # pylint: disable=abstract-method
"""BoltOn episilon-delta model.
Uses 4 key steps to achieve privacy guarantees:
1. Adds noise to weights after training (output perturbation).
2. Projects weights to R after each batch
3. Limits learning rate
4. Use a strongly convex loss function (see compile)
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et. al.
"""
def __init__(self, n_outputs=2, input_shape=(16,), init_value=2):
"""Constructor.
Args:
n_outputs: number of output neurons
input_shape:
init_value:
"""
super(TestModel, self).__init__(name='bolton', dynamic=False)
self.n_outputs = n_outputs
self.layer_input_shape = input_shape
self.output_layer = tf.keras.layers.Dense(
self.n_outputs,
input_shape=self.layer_input_shape,
kernel_regularizer=L1L2(l2=1),
kernel_initializer=constant(init_value),
)
class TestLoss(losses.Loss, StrongConvexMixin):
"""Test loss function for testing BoltOn model."""
def __init__(self, reg_lambda, c_arg, radius_constant, name='test'):
super(TestLoss, self).__init__(name=name)
self.reg_lambda = reg_lambda
self.C = c_arg # pylint: disable=invalid-name
self.radius_constant = radius_constant
def radius(self):
"""Radius, R, of the hypothesis space W.
W is a convex set that forms the hypothesis space.
Returns:
a tensor
"""
return _ops.convert_to_tensor_v2(self.radius_constant, dtype=tf.float32)
def gamma(self):
"""Returns strongly convex parameter, gamma."""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def beta(self, class_weight): # pylint: disable=unused-argument
"""Smoothness, beta.
Args:
class_weight: the class weights as scalar or 1d tensor, where its
dimensionality is equal to the number of outputs.
Returns:
Beta
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def lipchitz_constant(self, class_weight): # pylint: disable=unused-argument
"""Lipchitz constant, L.
Args:
class_weight: class weights used
Returns:
constant L
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def call(self, y_true, y_pred):
"""Loss function that is minimized at the mean of the input points."""
return 0.5 * tf.reduce_sum(
tf.math.squared_difference(y_true, y_pred),
axis=1
)
def max_class_weight(self, class_weight, dtype=tf.float32):
"""the maximum weighting in class weights (max value) as a scalar tensor.
Args:
class_weight: class weights used
dtype: the data type for tensor conversions.
Returns:
maximum class weighting as tensor scalar
"""
if class_weight is None:
return 1
raise NotImplementedError('')
def kernel_regularizer(self):
"""Returns the kernel_regularizer to be used.
Any subclass should override this method if they want a kernel_regularizer
(if required for the loss function to be StronglyConvex.
"""
return L1L2(l2=self.reg_lambda)
class TestOptimizer(OptimizerV2):
"""Optimizer used for testing the BoltOn optimizer."""
def __init__(self):
super(TestOptimizer, self).__init__('test')
self.not_private = 'test'
self.iterations = tf.constant(1, dtype=tf.float32)
self._iterations = tf.constant(1, dtype=tf.float32)
def _compute_gradients(self, loss, var_list, grad_loss=None):
return 'test'
def get_config(self):
return 'test'
def from_config(self, config, custom_objects=None):
return 'test'
def _create_slots(self):
return 'test'
def _resource_apply_dense(self, grad, handle):
return 'test'
def _resource_apply_sparse(self, grad, handle, indices):
return 'test'
def get_updates(self, loss, params):
return 'test'
def apply_gradients(self, grads_and_vars, name=None):
return 'test'
def minimize(self, loss, var_list, grad_loss=None, name=None):
return 'test'
def get_gradients(self, loss, params):
return 'test'
def limit_learning_rate(self):
return 'test'
class BoltonOptimizerTest(keras_parameterized.TestCase):
"""BoltOn Optimizer tests."""
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': 'getattr',
'fn': '__getattr__',
'args': ['dtype'],
'result': tf.float32,
'test_attr': None},
{'testcase_name': 'project_weights_to_r',
'fn': 'project_weights_to_r',
'args': ['dtype'],
'result': None,
'test_attr': ''},
])
def test_fn(self, fn, args, result, test_attr):
"""test that a fn of BoltOn optimizer is working as expected.
Args:
fn: method of Optimizer to test
args: args to optimizer fn
result: the expected result
test_attr: None if the fn returns the test result. Otherwise, this is
the attribute of BoltOn to check against result with.
"""
tf.random.set_seed(1)
loss = TestLoss(1, 1, 1)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(1)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
bolton._is_init = True # pylint: disable=protected-access
bolton.layers = model.layers
bolton.epsilon = 2
bolton.noise_distribution = 'laplace'
bolton.n_outputs = 1
bolton.n_samples = 1
res = getattr(bolton, fn, None)(*args)
if test_attr is not None:
res = getattr(bolton, test_attr, None)
if hasattr(res, 'numpy') and hasattr(result, 'numpy'): # both tensors/not
res = res.numpy()
result = result.numpy()
self.assertEqual(res, result)
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': '1 value project to r=1',
'r': 1,
'init_value': 2,
'shape': (1,),
'n_out': 1,
'result': [[1]]},
{'testcase_name': '2 value project to r=1',
'r': 1,
'init_value': 2,
'shape': (2,),
'n_out': 1,
'result': [[0.707107], [0.707107]]},
{'testcase_name': '1 value project to r=2',
'r': 2,
'init_value': 3,
'shape': (1,),
'n_out': 1,
'result': [[2]]},
{'testcase_name': 'no project',
'r': 2,
'init_value': 1,
'shape': (1,),
'n_out': 1,
'result': [[1]]},
])
def test_project(self, r, shape, n_out, init_value, result):
"""test that a fn of BoltOn optimizer is working as expected.
Args:
r: Radius value for StrongConvex loss function.
shape: input_dimensionality
n_out: output dimensionality
init_value: the initial value for 'constant' kernel initializer
result: the expected output after projection.
"""
tf.random.set_seed(1)
@tf.function
def project_fn(r):
loss = TestLoss(1, 1, r)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(n_out, shape, init_value)
model.compile(bolton, loss)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
bolton._is_init = True # pylint: disable=protected-access
bolton.layers = model.layers
bolton.epsilon = 2
bolton.noise_distribution = 'laplace'
bolton.n_outputs = 1
bolton.n_samples = 1
bolton.project_weights_to_r()
return _ops.convert_to_tensor_v2(bolton.layers[0].kernel, tf.float32)
res = project_fn(r)
self.assertAllClose(res, result)
@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
{'testcase_name': 'normal values',
'epsilon': 2,
'noise': 'laplace',
'class_weights': 1},
])
def test_context_manager(self, noise, epsilon, class_weights):
"""Tests the context manager functionality of the optimizer.
Args:
noise: noise distribution to pick
epsilon: epsilon privacy parameter to use
class_weights: class_weights to use
"""
@tf.function
def test_run():
loss = TestLoss(1, 1, 1)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(1, (1,), 1)
model.compile(bolton, loss)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
with bolton(noise, epsilon, model.layers, class_weights, 1, 1) as _:
pass
return _ops.convert_to_tensor_v2(bolton.epsilon, dtype=tf.float32)
epsilon = test_run()
self.assertEqual(epsilon.numpy(), -1)
@parameterized.named_parameters([
{'testcase_name': 'invalid noise',
'epsilon': 1,
'noise': 'not_valid',
'err_msg': 'Detected noise distribution: not_valid not one of:'},
{'testcase_name': 'invalid epsilon',
'epsilon': -1,
'noise': 'laplace',
'err_msg': 'Detected epsilon: -1. Valid range is 0 < epsilon <inf'},
])
def test_context_domains(self, noise, epsilon, err_msg):
"""Tests the context domains.
Args:
noise: noise distribution to pick
epsilon: epsilon privacy parameter to use
err_msg: the expected error message
"""
@tf.function
def test_run(noise, epsilon):
loss = TestLoss(1, 1, 1)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(1, (1,), 1)
model.compile(bolton, loss)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
with bolton(noise, epsilon, model.layers, 1, 1, 1) as _:
pass
with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method
test_run(noise, epsilon)
@parameterized.named_parameters([
{'testcase_name': 'fn: get_noise',
'fn': 'get_noise',
'args': [1, 1],
'err_msg': 'ust be called from within the optimizer\'s context'},
])
def test_not_in_context(self, fn, args, err_msg):
"""Tests that the expected functions raise errors when not in context.
Args:
fn: the function to test
args: the arguments for said function
err_msg: expected error message
"""
@tf.function
def test_run(fn, args):
loss = TestLoss(1, 1, 1)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(1, (1,), 1)
model.compile(bolton, loss)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
getattr(bolton, fn)(*args)
with self.assertRaisesRegexp(Exception, err_msg): # pylint: disable=deprecated-method
test_run(fn, args)
@parameterized.named_parameters([
{'testcase_name': 'fn: get_updates',
'fn': 'get_updates',
'args': [0, 0]},
{'testcase_name': 'fn: get_config',
'fn': 'get_config',
'args': []},
{'testcase_name': 'fn: from_config',
'fn': 'from_config',
'args': [0]},
{'testcase_name': 'fn: _resource_apply_dense',
'fn': '_resource_apply_dense',
'args': [1, 1]},
{'testcase_name': 'fn: _resource_apply_sparse',
'fn': '_resource_apply_sparse',
'args': [1, 1, 1]},
{'testcase_name': 'fn: apply_gradients',
'fn': 'apply_gradients',
'args': [1]},
{'testcase_name': 'fn: minimize',
'fn': 'minimize',
'args': [1, 1]},
{'testcase_name': 'fn: _compute_gradients',
'fn': '_compute_gradients',
'args': [1, 1]},
{'testcase_name': 'fn: get_gradients',
'fn': 'get_gradients',
'args': [1, 1]},
])
def test_rerouted_function(self, fn, args):
"""Tests rerouted function.
Tests that a method of the internal optimizer is correctly routed from
the BoltOn instance to the internal optimizer instance (TestOptimizer,
here).
Args:
fn: fn to test
args: arguments to that fn
"""
loss = TestLoss(1, 1, 1)
optimizer = TestOptimizer()
bolton = opt.BoltOn(optimizer, loss)
model = TestModel(3)
model.compile(optimizer, loss)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
bolton._is_init = True # pylint: disable=protected-access
bolton.layers = model.layers
bolton.epsilon = 2
bolton.noise_distribution = 'laplace'
bolton.n_outputs = 1
bolton.n_samples = 1
self.assertEqual(
getattr(bolton, fn, lambda: 'fn not found')(*args),
'test'
)
@parameterized.named_parameters([
{'testcase_name': 'fn: project_weights_to_r',
'fn': 'project_weights_to_r',
'args': []},
{'testcase_name': 'fn: get_noise',
'fn': 'get_noise',
'args': [1, 1]},
])
def test_not_reroute_fn(self, fn, args):
"""Test function is not rerouted.
Test that a fn that should not be rerouted to the internal optimizer is
in fact not rerouted.
Args:
fn: fn to test
args: arguments to that fn
"""
@tf.function
def test_run(fn, args):
loss = TestLoss(1, 1, 1)
bolton = opt.BoltOn(TestOptimizer(), loss)
model = TestModel(1, (1,), 1)
model.compile(bolton, loss)
model.layers[0].kernel = \
model.layers[0].kernel_initializer((model.layer_input_shape[0],
model.n_outputs))
bolton._is_init = True # pylint: disable=protected-access
bolton.noise_distribution = 'laplace'
bolton.epsilon = 1
bolton.layers = model.layers
bolton.class_weights = 1
bolton.n_samples = 1
bolton.batch_size = 1
bolton.n_outputs = 1
res = getattr(bolton, fn, lambda: 'test')(*args)
if res != 'test':
res = 1
else:
res = 0
return _ops.convert_to_tensor_v2(res, dtype=tf.float32)
self.assertNotEqual(test_run(fn, args), 0)
@parameterized.named_parameters([
{'testcase_name': 'attr: _iterations',
'attr': '_iterations'}
])
def test_reroute_attr(self, attr):
"""Test a function is rerouted.
Test that attribute of internal optimizer is correctly rerouted to the
internal optimizer.
Args:
attr: attribute to test
"""
loss = TestLoss(1, 1, 1)
internal_optimizer = TestOptimizer()
optimizer = opt.BoltOn(internal_optimizer, loss)
self.assertEqual(getattr(optimizer, attr),
getattr(internal_optimizer, attr))
@parameterized.named_parameters([
{'testcase_name': 'attr does not exist',
'attr': '_not_valid'}
])
def test_attribute_error(self, attr):
"""Test rerouting of attributes.
Test that attribute of internal optimizer is correctly rerouted to the
internal optimizer
Args:
attr: attribute to test
"""
loss = TestLoss(1, 1, 1)
internal_optimizer = TestOptimizer()
optimizer = opt.BoltOn(internal_optimizer, loss)
with self.assertRaises(AttributeError):
getattr(optimizer, attr)
class SchedulerTest(keras_parameterized.TestCase):
"""GammaBeta Scheduler tests."""
@parameterized.named_parameters([
{'testcase_name': 'not in context',
'err_msg': 'Please initialize the GammaBetaDecreasingStep Learning Rate'
' Scheduler'
}
])
def test_bad_call(self, err_msg):
"""Test attribute of internal opt correctly rerouted to the internal opt.
Args:
err_msg: The expected error message from the scheduler bad call.
"""
scheduler = opt.GammaBetaDecreasingStep()
with self.assertRaisesRegexp(Exception, err_msg): # pylint: disable=deprecated-method
scheduler(1)
@parameterized.named_parameters([
{'testcase_name': 'step 1',
'step': 1,
'res': 0.5},
{'testcase_name': 'step 2',
'step': 2,
'res': 0.5},
{'testcase_name': 'step 3',
'step': 3,
'res': 0.333333333},
])
def test_call(self, step, res):
"""Test call.
Test that attribute of internal optimizer is correctly rerouted to the
internal optimizer
Args:
step: step number to 'GammaBetaDecreasingStep' 'Scheduler'.
res: expected result from call to 'GammaBetaDecreasingStep' 'Scheduler'.
"""
beta = _ops.convert_to_tensor_v2(2, dtype=tf.float32)
gamma = _ops.convert_to_tensor_v2(1, dtype=tf.float32)
scheduler = opt.GammaBetaDecreasingStep()
scheduler.initialize(beta, gamma)
step = _ops.convert_to_tensor_v2(step, dtype=tf.float32)
lr = scheduler(step)
self.assertAllClose(lr.numpy(), res)
if __name__ == '__main__':
test.main()

View file

@ -0,0 +1,188 @@
# Copyright 2019, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tutorial for bolt_on module, the model and the optimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=wrong-import-position
from privacy.bolt_on import losses # pylint: disable=wrong-import-position
from privacy.bolt_on import models # pylint: disable=wrong-import-position
from privacy.bolt_on.optimizers import BoltOn # pylint: disable=wrong-import-position
# -------
# First, we will create a binary classification dataset with a single output
# dimension. The samples for each label are repeated data points at different
# points in space.
# -------
# Parameters for dataset
n_samples = 10
input_dim = 2
n_outputs = 1
# Create binary classification dataset:
x_stack = [tf.constant(-1, tf.float32, (n_samples, input_dim)),
tf.constant(1, tf.float32, (n_samples, input_dim))]
y_stack = [tf.constant(0, tf.float32, (n_samples, 1)),
tf.constant(1, tf.float32, (n_samples, 1))]
x, y = tf.concat(x_stack, 0), tf.concat(y_stack, 0)
print(x.shape, y.shape)
generator = tf.data.Dataset.from_tensor_slices((x, y))
generator = generator.batch(10)
generator = generator.shuffle(10)
# -------
# First, we will explore using the pre - built BoltOnModel, which is a thin
# wrapper around a Keras Model using a single - layer neural network.
# It automatically uses the BoltOn Optimizer which encompasses all the logic
# required for the BoltOn Differential Privacy method.
# -------
bolt = models.BoltOnModel(n_outputs) # tell the model how many outputs we have.
# -------
# Now, we will pick our optimizer and Strongly Convex Loss function. The loss
# must extend from StrongConvexMixin and implement the associated methods.Some
# existing loss functions are pre - implemented in bolt_on.loss
# -------
optimizer = tf.optimizers.SGD()
reg_lambda = 1
C = 1
radius_constant = 1
loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
# -------
# For simplicity, we pick all parameters of the StrongConvexBinaryCrossentropy
# to be 1; these are all tunable and their impact can be read in losses.
# StrongConvexBinaryCrossentropy.We then compile the model with the chosen
# optimizer and loss, which will automatically wrap the chosen optimizer with
# the BoltOn Optimizer, ensuring the required components function as required
# for privacy guarantees.
# -------
bolt.compile(optimizer, loss)
# -------
# To fit the model, the optimizer will require additional information about
# the dataset and model.These parameters are:
# 1. the class_weights used
# 2. the number of samples in the dataset
# 3. the batch size which the model will try to infer, if possible. If not,
# you will be required to pass these explicitly to the fit method.
#
# As well, there are two privacy parameters than can be altered:
# 1. epsilon, a float
# 2. noise_distribution, a valid string indicating the distriution to use (must
# be implemented)
#
# The BoltOnModel offers a helper method,.calculate_class_weight to aid in
# class_weight calculation.
# required parameters
# -------
class_weight = None # default, use .calculate_class_weight for other values
batch_size = None # default, if it cannot be inferred, specify this
n_samples = None # default, if it cannot be iferred, specify this
# privacy parameters
epsilon = 2
noise_distribution = 'laplace'
bolt.fit(x,
y,
epsilon=epsilon,
class_weight=class_weight,
batch_size=batch_size,
n_samples=n_samples,
noise_distribution=noise_distribution,
epochs=2)
# -------
# We may also train a generator object, or try different optimizers and loss
# functions. Below, we will see that we must pass the number of samples as the
# fit method is unable to infer it for a generator.
# -------
optimizer2 = tf.optimizers.Adam()
bolt.compile(optimizer2, loss)
# required parameters
class_weight = None # default, use .calculate_class_weight for other values
batch_size = None # default, if it cannot be inferred, specify this
n_samples = None # default, if it cannot be iferred, specify this
# privacy parameters
epsilon = 2
noise_distribution = 'laplace'
try:
bolt.fit(generator,
epsilon=epsilon,
class_weight=class_weight,
batch_size=batch_size,
n_samples=n_samples,
noise_distribution=noise_distribution,
verbose=0)
except ValueError as e:
print(e)
# -------
# And now, re running with the parameter set.
# -------
n_samples = 20
bolt.fit(generator,
epsilon=epsilon,
class_weight=class_weight,
batch_size=batch_size,
n_samples=n_samples,
noise_distribution=noise_distribution,
verbose=0)
# -------
# You don't have to use the BoltOn model to use the BoltOn method.
# There are only a few requirements:
# 1. make sure any requirements from the loss are implemented in the model.
# 2. instantiate the optimizer and use it as a context around the fit operation.
# -------
# -------------------- Part 2, using the Optimizer
# -------
# Here, we create our own model and setup the BoltOn optimizer.
# -------
class TestModel(tf.keras.Model): # pylint: disable=abstract-method
def __init__(self, reg_layer, number_of_outputs=1):
super(TestModel, self).__init__(name='test')
self.output_layer = tf.keras.layers.Dense(number_of_outputs,
kernel_regularizer=reg_layer)
def call(self, inputs): # pylint: disable=arguments-differ
return self.output_layer(inputs)
optimizer = tf.optimizers.SGD()
loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
optimizer = BoltOn(optimizer, loss)
# -------
# Now, we instantiate our model and check for 1. Since our loss requires L2
# regularization over the kernel, we will pass it to the model.
# -------
n_outputs = 1 # parameter for model and optimizer context.
test_model = TestModel(loss.kernel_regularizer(), n_outputs)
test_model.compile(optimizer, loss)
# -------
# We comply with 2., and use the BoltOn Optimizer as a context around the fit
# method.
# -------
# parameters for context
noise_distribution = 'laplace'
epsilon = 2
class_weights = 1 # Previously, the fit method auto-detected the class_weights.
# Here, we need to pass the class_weights explicitly. 1 is the same as None.
n_samples = 20
batch_size = 5
with optimizer(
noise_distribution=noise_distribution,
epsilon=epsilon,
layers=test_model.layers,
class_weights=class_weights,
n_samples=n_samples,
batch_size=batch_size
) as _:
test_model.fit(x, y, batch_size=batch_size, epochs=2)