Merge pull request #3 from georgianpartners/bolton

Code review changes
This commit is contained in:
Christopher Choquette Choo 2019-07-16 20:02:31 -04:00 committed by GitHub
commit eab43e8294
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 681 additions and 638 deletions

View file

@ -1,3 +1,17 @@
# Copyright 2019, The TensorFlow Privacy Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Bolton Method for privacy."""
import sys import sys
from distutils.version import LooseVersion from distutils.version import LooseVersion
import tensorflow as tf import tensorflow as tf
@ -12,4 +26,4 @@ else:
from privacy.bolton.models import BoltonModel from privacy.bolton.models import BoltonModel
from privacy.bolton.optimizers import Bolton from privacy.bolton.optimizers import Bolton
from privacy.bolton.losses import StrongConvexHuber from privacy.bolton.losses import StrongConvexHuber
from privacy.bolton.losses import StrongConvexBinaryCrossentropy from privacy.bolton.losses import StrongConvexBinaryCrossentropy

View file

@ -1,4 +1,4 @@
# Copyright 2018, The TensorFlow Authors. # Copyright 2019, The TensorFlow Authors.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -40,53 +40,47 @@ class StrongConvexMixin:
"""Radius, R, of the hypothesis space W. """Radius, R, of the hypothesis space W.
W is a convex set that forms the hypothesis space. W is a convex set that forms the hypothesis space.
Returns: R Returns:
R
""" """
raise NotImplementedError("Radius not implemented for StrongConvex Loss" raise NotImplementedError("Radius not implemented for StrongConvex Loss"
"function: %s" % str(self.__class__.__name__)) "function: %s" % str(self.__class__.__name__))
def gamma(self): def gamma(self):
""" Strongly convexity, gamma """Returns strongly convex parameter, gamma."""
Returns: gamma
"""
raise NotImplementedError("Gamma not implemented for StrongConvex Loss" raise NotImplementedError("Gamma not implemented for StrongConvex Loss"
"function: %s" % str(self.__class__.__name__)) "function: %s" % str(self.__class__.__name__))
def beta(self, class_weight): def beta(self, class_weight):
"""Smoothness, beta """Smoothness, beta.
Args: Args:
class_weight: the class weights as scalar or 1d tensor, where its class_weight: the class weights as scalar or 1d tensor, where its
dimensionality is equal to the number of outputs. dimensionality is equal to the number of outputs.
Returns: Beta Returns:
Beta
""" """
raise NotImplementedError("Beta not implemented for StrongConvex Loss" raise NotImplementedError("Beta not implemented for StrongConvex Loss"
"function: %s" % str(self.__class__.__name__)) "function: %s" % str(self.__class__.__name__))
def lipchitz_constant(self, class_weight): def lipchitz_constant(self, class_weight):
"""Lipchitz constant, L """Lipchitz constant, L.
Args: Args:
class_weight: class weights used class_weight: class weights used
Returns: L Returns: L
""" """
raise NotImplementedError("lipchitz constant not implemented for " raise NotImplementedError("lipchitz constant not implemented for "
"StrongConvex Loss" "StrongConvex Loss"
"function: %s" % str(self.__class__.__name__)) "function: %s" % str(self.__class__.__name__))
def kernel_regularizer(self): def kernel_regularizer(self):
"""returns the kernel_regularizer to be used. Any subclass should override """Returns the kernel_regularizer to be used.
this method if they want a kernel_regularizer (if required for
the loss function to be StronglyConvex
:return: None or kernel_regularizer layer Any subclass should override this method if they want a kernel_regularizer
(if required for the loss function to be StronglyConvex.
""" """
return None return None
@ -97,16 +91,15 @@ class StrongConvexMixin:
class_weight: class weights used class_weight: class weights used
dtype: the data type for tensor conversions. dtype: the data type for tensor conversions.
Returns: maximum class weighting as tensor scalar Returns:
maximum class weighting as tensor scalar
""" """
class_weight = _ops.convert_to_tensor_v2(class_weight, dtype) class_weight = _ops.convert_to_tensor_v2(class_weight, dtype)
return tf.math.reduce_max(class_weight) return tf.math.reduce_max(class_weight)
class StrongConvexHuber(losses.Loss, StrongConvexMixin): class StrongConvexHuber(losses.Loss, StrongConvexMixin):
"""Strong Convex version of Huber loss using l2 weight regularization. """Strong Convex version of Huber loss using l2 weight regularization."""
"""
def __init__(self, def __init__(self,
reg_lambda: float, reg_lambda: float,
@ -153,7 +146,7 @@ class StrongConvexHuber(losses.Loss, StrongConvexMixin):
) )
def call(self, y_true, y_pred): def call(self, y_true, y_pred):
"""Compute loss """Computes loss
Args: Args:
y_true: Ground truth values. One hot encoded using -1 and 1. y_true: Ground truth values. One hot encoded using -1 and 1.
@ -162,7 +155,6 @@ class StrongConvexHuber(losses.Loss, StrongConvexMixin):
Returns: Returns:
Loss values per sample. Loss values per sample.
""" """
# return super(StrongConvexHuber, self).call(y_true, y_pred) * self._sample_weight
h = self.delta h = self.delta
z = y_pred * y_true z = y_pred * y_true
one = tf.constant(1, dtype=self.dtype) one = tf.constant(1, dtype=self.dtype)
@ -172,23 +164,18 @@ class StrongConvexHuber(losses.Loss, StrongConvexMixin):
return _ops.convert_to_tensor_v2(0, dtype=self.dtype) return _ops.convert_to_tensor_v2(0, dtype=self.dtype)
elif tf.math.abs(one - z) <= h: elif tf.math.abs(one - z) <= h:
return one / (four * h) * tf.math.pow(one + h - z, 2) return one / (four * h) * tf.math.pow(one + h - z, 2)
elif z < one - h: return one - z # elif: z < one - h
return one - z
raise ValueError('') # shouldn't be possible to get here.
def radius(self): def radius(self):
"""See super class. """See super class."""
"""
return self.radius_constant / self.reg_lambda return self.radius_constant / self.reg_lambda
def gamma(self): def gamma(self):
"""See super class. """See super class."""
"""
return self.reg_lambda return self.reg_lambda
def beta(self, class_weight): def beta(self, class_weight):
"""See super class. """See super class."""
"""
max_class_weight = self.max_class_weight(class_weight, self.dtype) max_class_weight = self.max_class_weight(class_weight, self.dtype)
delta = _ops.convert_to_tensor_v2(self.delta, delta = _ops.convert_to_tensor_v2(self.delta,
dtype=self.dtype dtype=self.dtype
@ -198,8 +185,7 @@ class StrongConvexHuber(losses.Loss, StrongConvexMixin):
self.reg_lambda self.reg_lambda
def lipchitz_constant(self, class_weight): def lipchitz_constant(self, class_weight):
"""See super class. """See super class."""
"""
# if class_weight is provided, # if class_weight is provided,
# it should be a vector of the same size of number of classes # it should be a vector of the same size of number of classes
max_class_weight = self.max_class_weight(class_weight, self.dtype) max_class_weight = self.max_class_weight(class_weight, self.dtype)
@ -208,10 +194,13 @@ class StrongConvexHuber(losses.Loss, StrongConvexMixin):
return lc return lc
def kernel_regularizer(self): def kernel_regularizer(self):
""" """Return l2 loss using 0.5*reg_lambda as the l2 term (as desired).
l2 loss using reg_lambda as the l2 term (as desired). Required for
this loss function to be strongly convex. L2 regularization is required for this loss function to be strongly convex.
:return:
Returns:
The L2 regularizer layer for this loss function, with regularizer constant
set to half the 0.5 * reg_lambda.
""" """
return L1L2(l2=self.reg_lambda/2) return L1L2(l2=self.reg_lambda/2)
@ -220,10 +209,7 @@ class StrongConvexBinaryCrossentropy(
losses.BinaryCrossentropy, losses.BinaryCrossentropy,
StrongConvexMixin StrongConvexMixin
): ):
""" """Strongly Convex BinaryCrossentropy loss using l2 weight regularization."""
Strong Convex version of BinaryCrossentropy loss using l2 weight
regularization.
"""
def __init__(self, def __init__(self,
reg_lambda: float, reg_lambda: float,
@ -239,10 +225,12 @@ class StrongConvexBinaryCrossentropy(
C: Penalty parameter C of the loss term C: Penalty parameter C of the loss term
radius_constant: constant defining the length of the radius radius_constant: constant defining the length of the radius
reduction: reduction type to use. See super class reduction: reduction type to use. See super class
from_logits: True if the input are unscaled logits. False if they are
already scaled.
label_smoothing: amount of smoothing to perform on labels label_smoothing: amount of smoothing to perform on labels
relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x). relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x). Note, the
Note, the impact of this parameter's effect on privacy impact of this parameter's effect on privacy is not known and thus the
is not known and thus the default should be used. default should be used.
name: Name of the loss instance name: Name of the loss instance
dtype: tf datatype to use for tensor conversions. dtype: tf datatype to use for tensor conversions.
""" """
@ -271,49 +259,322 @@ class StrongConvexBinaryCrossentropy(
self.radius_constant = radius_constant self.radius_constant = radius_constant
def call(self, y_true, y_pred): def call(self, y_true, y_pred):
"""Compute loss """Computes loss
Args: Args:
y_true: Ground truth values. y_true: Ground truth values.
y_pred: The predicted values. y_pred: The predicted values.
Returns: Returns:
Loss values per sample. Loss values per sample.
""" """
# loss = tf.nn.sigmoid_cross_entropy_with_logits(
# labels=y_true,
# logits=y_pred
# )
loss = super(StrongConvexBinaryCrossentropy, self).call(y_true, y_pred) loss = super(StrongConvexBinaryCrossentropy, self).call(y_true, y_pred)
loss = loss * self.C loss = loss * self.C
return loss return loss
def radius(self): def radius(self):
"""See super class. """See super class."""
"""
return self.radius_constant / self.reg_lambda return self.radius_constant / self.reg_lambda
def gamma(self): def gamma(self):
"""See super class. """See super class."""
"""
return self.reg_lambda return self.reg_lambda
def beta(self, class_weight): def beta(self, class_weight):
"""See super class. """See super class."""
"""
max_class_weight = self.max_class_weight(class_weight, self.dtype) max_class_weight = self.max_class_weight(class_weight, self.dtype)
return self.C * max_class_weight + self.reg_lambda return self.C * max_class_weight + self.reg_lambda
def lipchitz_constant(self, class_weight): def lipchitz_constant(self, class_weight):
"""See super class. """See super class."""
"""
max_class_weight = self.max_class_weight(class_weight, self.dtype) max_class_weight = self.max_class_weight(class_weight, self.dtype)
return self.C * max_class_weight + self.reg_lambda * self.radius() return self.C * max_class_weight + self.reg_lambda * self.radius()
def kernel_regularizer(self): def kernel_regularizer(self):
""" """Return l2 loss using 0.5*reg_lambda as the l2 term (as desired).
l2 loss using reg_lambda as the l2 term (as desired). Required for
this loss function to be strongly convex. L2 regularization is required for this loss function to be strongly convex.
:return:
Returns:
The L2 regularizer layer for this loss function, with regularizer constant
set to half the 0.5 * reg_lambda.
""" """
return L1L2(l2=self.reg_lambda/2) return L1L2(l2=self.reg_lambda/2)
# class StrongConvexSparseCategoricalCrossentropy(
# losses.CategoricalCrossentropy,
# StrongConvexMixin
# ):
# """
# Strong Convex version of CategoricalCrossentropy loss using l2 weight
# regularization.
# """
#
# def __init__(self,
# reg_lambda: float,
# C: float,
# radius_constant: float,
# from_logits: bool = True,
# label_smoothing: float = 0,
# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
# name: str = 'binarycrossentropy',
# dtype=tf.float32):
# """
# Args:
# reg_lambda: Weight regularization constant
# C: Penalty parameter C of the loss term
# radius_constant: constant defining the length of the radius
# reduction: reduction type to use. See super class
# label_smoothing: amount of smoothing to perform on labels
# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x)
# name: Name of the loss instance
# dtype: tf datatype to use for tensor conversions.
# """
# if reg_lambda <= 0:
# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
# if C <= 0:
# raise ValueError('c: {0}, should be >= 0'.format(C))
# if radius_constant <= 0:
# raise ValueError('radius_constant: {0}, should be >= 0'.format(
# radius_constant
# ))
#
# self.C = C
# self.dtype = dtype
# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
# super(StrongConvexSparseCategoricalCrossentropy, self).__init__(
# reduction=reduction,
# name=name,
# from_logits=from_logits,
# label_smoothing=label_smoothing,
# )
# self.radius_constant = radius_constant
#
# def call(self, y_true, y_pred):
# """Compute loss
#
# Args:
# y_true: Ground truth values.
# y_pred: The predicted values.
#
# Returns:
# Loss values per sample.
# """
# loss = super()
# loss = loss * self.C
# return loss
#
# def radius(self):
# """See super class.
# """
# return self.radius_constant / self.reg_lambda
#
# def gamma(self):
# """See super class.
# """
# return self.reg_lambda
#
# def beta(self, class_weight):
# """See super class.
# """
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
# return self.C * max_class_weight + self.reg_lambda
#
# def lipchitz_constant(self, class_weight):
# """See super class.
# """
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
# return self.C * max_class_weight + self.reg_lambda * self.radius()
#
# def kernel_regularizer(self):
# """
# l2 loss using reg_lambda as the l2 term (as desired). Required for
# this loss function to be strongly convex.
# :return:
# """
# return L1L2(l2=self.reg_lambda)
#
# class StrongConvexSparseCategoricalCrossentropy(
# losses.SparseCategoricalCrossentropy,
# StrongConvexMixin
# ):
# """
# Strong Convex version of SparseCategoricalCrossentropy loss using l2 weight
# regularization.
# """
#
# def __init__(self,
# reg_lambda: float,
# C: float,
# radius_constant: float,
# from_logits: bool = True,
# label_smoothing: float = 0,
# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
# name: str = 'binarycrossentropy',
# dtype=tf.float32):
# """
# Args:
# reg_lambda: Weight regularization constant
# C: Penalty parameter C of the loss term
# radius_constant: constant defining the length of the radius
# reduction: reduction type to use. See super class
# label_smoothing: amount of smoothing to perform on labels
# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x)
# name: Name of the loss instance
# dtype: tf datatype to use for tensor conversions.
# """
# if reg_lambda <= 0:
# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
# if C <= 0:
# raise ValueError('c: {0}, should be >= 0'.format(C))
# if radius_constant <= 0:
# raise ValueError('radius_constant: {0}, should be >= 0'.format(
# radius_constant
# ))
#
# self.C = C
# self.dtype = dtype
# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
# super(StrongConvexHuber, self).__init__(reduction=reduction,
# name=name,
# from_logits=from_logits,
# label_smoothing=label_smoothing,
# )
# self.radius_constant = radius_constant
#
# def call(self, y_true, y_pred):
# """Compute loss
#
# Args:
# y_true: Ground truth values.
# y_pred: The predicted values.
#
# Returns:
# Loss values per sample.
# """
# loss = super()
# loss = loss * self.C
# return loss
#
# def radius(self):
# """See super class.
# """
# return self.radius_constant / self.reg_lambda
#
# def gamma(self):
# """See super class.
# """
# return self.reg_lambda
#
# def beta(self, class_weight):
# """See super class.
# """
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
# return self.C * max_class_weight + self.reg_lambda
#
# def lipchitz_constant(self, class_weight):
# """See super class.
# """
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
# return self.C * max_class_weight + self.reg_lambda * self.radius()
#
# def kernel_regularizer(self):
# """
# l2 loss using reg_lambda as the l2 term (as desired). Required for
# this loss function to be strongly convex.
# :return:
# """
# return L1L2(l2=self.reg_lambda)
#
#
# class StrongConvexCategoricalCrossentropy(
# losses.CategoricalCrossentropy,
# StrongConvexMixin
# ):
# """
# Strong Convex version of CategoricalCrossentropy loss using l2 weight
# regularization.
# """
#
# def __init__(self,
# reg_lambda: float,
# C: float,
# radius_constant: float,
# from_logits: bool = True,
# label_smoothing: float = 0,
# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
# name: str = 'binarycrossentropy',
# dtype=tf.float32):
# """
# Args:
# reg_lambda: Weight regularization constant
# C: Penalty parameter C of the loss term
# radius_constant: constant defining the length of the radius
# reduction: reduction type to use. See super class
# label_smoothing: amount of smoothing to perform on labels
# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x)
# name: Name of the loss instance
# dtype: tf datatype to use for tensor conversions.
# """
# if reg_lambda <= 0:
# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda))
# if C <= 0:
# raise ValueError('c: {0}, should be >= 0'.format(C))
# if radius_constant <= 0:
# raise ValueError('radius_constant: {0}, should be >= 0'.format(
# radius_constant
# ))
#
# self.C = C
# self.dtype = dtype
# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype)
# super(StrongConvexHuber, self).__init__(reduction=reduction,
# name=name,
# from_logits=from_logits,
# label_smoothing=label_smoothing,
# )
# self.radius_constant = radius_constant
#
# def call(self, y_true, y_pred):
# """Compute loss
#
# Args:
# y_true: Ground truth values.
# y_pred: The predicted values.
#
# Returns:
# Loss values per sample.
# """
# loss = super()
# loss = loss * self.C
# return loss
#
# def radius(self):
# """See super class.
# """
# return self.radius_constant / self.reg_lambda
#
# def gamma(self):
# """See super class.
# """
# return self.reg_lambda
#
# def beta(self, class_weight):
# """See super class.
# """
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
# return self.C * max_class_weight + self.reg_lambda
#
# def lipchitz_constant(self, class_weight):
# """See super class.
# """
# max_class_weight = self.max_class_weight(class_weight, self.dtype)
# return self.C * max_class_weight + self.reg_lambda * self.radius()
#
# def kernel_regularizer(self):
# """
# l2 loss using reg_lambda as the l2 term (as desired). Required for
# this loss function to be strongly convex.
# :return:
# """
# return L1L2(l2=self.reg_lambda)

View file

@ -1,4 +1,4 @@
# Copyright 2018, The TensorFlow Authors. # Copyright 2019, The TensorFlow Authors.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,6 +17,9 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from contextlib import contextmanager
from io import StringIO
import sys
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import keras_parameterized
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
@ -27,6 +30,18 @@ from privacy.bolton.losses import StrongConvexHuber
from privacy.bolton.losses import StrongConvexMixin from privacy.bolton.losses import StrongConvexMixin
@contextmanager
def captured_output():
"""Capture std_out and std_err within context."""
new_out, new_err = StringIO(), StringIO()
old_out, old_err = sys.stdout, sys.stderr
try:
sys.stdout, sys.stderr = new_out, new_err
yield sys.stdout, sys.stderr
finally:
sys.stdout, sys.stderr = old_out, old_err
class StrongConvexMixinTests(keras_parameterized.TestCase): class StrongConvexMixinTests(keras_parameterized.TestCase):
"""Tests for the StrongConvexMixin""" """Tests for the StrongConvexMixin"""
@parameterized.named_parameters([ @parameterized.named_parameters([
@ -72,7 +87,7 @@ class StrongConvexMixinTests(keras_parameterized.TestCase):
class BinaryCrossesntropyTests(keras_parameterized.TestCase): class BinaryCrossesntropyTests(keras_parameterized.TestCase):
"""tests for BinaryCrossesntropy StrongConvex loss""" """tests for BinaryCrossesntropy StrongConvex loss."""
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'normal', {'testcase_name': 'normal',
@ -82,7 +97,8 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
}, # pylint: disable=invalid-name }, # pylint: disable=invalid-name
]) ])
def test_init_params(self, reg_lambda, C, radius_constant): def test_init_params(self, reg_lambda, C, radius_constant):
"""Test initialization for given arguments """Test initialization for given arguments.
Args: Args:
reg_lambda: initialization value for reg_lambda arg reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg C: initialization value for C arg
@ -111,6 +127,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
]) ])
def test_bad_init_params(self, reg_lambda, C, radius_constant): def test_bad_init_params(self, reg_lambda, C, radius_constant):
"""Test invalid domain for given params. Should return ValueError """Test invalid domain for given params. Should return ValueError
Args: Args:
reg_lambda: initialization value for reg_lambda arg reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg C: initialization value for C arg
@ -146,6 +163,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
]) ])
def test_calculation(self, logits, y_true, result): def test_calculation(self, logits, y_true, result):
"""Test the call method to ensure it returns the correct value """Test the call method to ensure it returns the correct value
Args: Args:
logits: unscaled output of model logits: unscaled output of model
y_true: label y_true: label
@ -185,6 +203,7 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
]) ])
def test_fns(self, init_args, fn, args, result): def test_fns(self, init_args, fn, args, result):
"""Test that fn of BinaryCrossentropy loss returns the correct result """Test that fn of BinaryCrossentropy loss returns the correct result
Args: Args:
init_args: init values for loss instance init_args: init values for loss instance
fn: the fn to test fn: the fn to test
@ -201,6 +220,29 @@ class BinaryCrossesntropyTests(keras_parameterized.TestCase):
result = result.l2 result = result.l2
self.assertEqual(expected, result) self.assertEqual(expected, result)
@parameterized.named_parameters([
{'testcase_name': 'label_smoothing',
'init_args': [1, 1, 1, True, 0.1],
'fn': None,
'args': None,
'print_res': 'The impact of label smoothing on privacy is unknown.'
},
])
def test_prints(self, init_args, fn, args, print_res):
"""Test logger warning from StrongConvexBinaryCrossentropy.
Args:
init_args: arguments to init the object with.
fn: function to test
args: arguments to above function
print_res: print result that should have been printed.
"""
with captured_output() as (out, err): # pylint: disable=unused-variable
loss = StrongConvexBinaryCrossentropy(*init_args)
if fn is not None:
getattr(loss, fn, lambda *arguments: print('error'))(*args)
self.assertRegexMatch(err.getvalue().strip(), [print_res])
class HuberTests(keras_parameterized.TestCase): class HuberTests(keras_parameterized.TestCase):
"""tests for BinaryCrossesntropy StrongConvex loss""" """tests for BinaryCrossesntropy StrongConvex loss"""
@ -215,6 +257,7 @@ class HuberTests(keras_parameterized.TestCase):
]) ])
def test_init_params(self, reg_lambda, c, radius_constant, delta): def test_init_params(self, reg_lambda, c, radius_constant, delta):
"""Test initialization for given arguments """Test initialization for given arguments
Args: Args:
reg_lambda: initialization value for reg_lambda arg reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg C: initialization value for C arg
@ -244,7 +287,7 @@ class HuberTests(keras_parameterized.TestCase):
'delta': 1 'delta': 1
}, },
{'testcase_name': 'negative delta', {'testcase_name': 'negative delta',
'reg_lambda': -1, 'reg_lambda': 1,
'c': 1, 'c': 1,
'radius_constant': 1, 'radius_constant': 1,
'delta': -1 'delta': -1
@ -252,10 +295,12 @@ class HuberTests(keras_parameterized.TestCase):
]) ])
def test_bad_init_params(self, reg_lambda, c, radius_constant, delta): def test_bad_init_params(self, reg_lambda, c, radius_constant, delta):
"""Test invalid domain for given params. Should return ValueError """Test invalid domain for given params. Should return ValueError
Args: Args:
reg_lambda: initialization value for reg_lambda arg reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg C: initialization value for C arg
radius_constant: initialization value for radius_constant arg radius_constant: initialization value for radius_constant arg
delta: the delta parameter for the huber loss
""" """
# test valid domains for each variable # test valid domains for each variable
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -321,6 +366,7 @@ class HuberTests(keras_parameterized.TestCase):
]) ])
def test_calculation(self, logits, y_true, delta, result): def test_calculation(self, logits, y_true, delta, result):
"""Test the call method to ensure it returns the correct value """Test the call method to ensure it returns the correct value
Args: Args:
logits: unscaled output of model logits: unscaled output of model
y_true: label y_true: label
@ -360,6 +406,7 @@ class HuberTests(keras_parameterized.TestCase):
]) ])
def test_fns(self, init_args, fn, args, result): def test_fns(self, init_args, fn, args, result):
"""Test that fn of BinaryCrossentropy loss returns the correct result """Test that fn of BinaryCrossentropy loss returns the correct result
Args: Args:
init_args: init values for loss instance init_args: init values for loss instance
fn: the fn to test fn: the fn to test

View file

@ -1,4 +1,4 @@
# Copyright 2018, The TensorFlow Authors. # Copyright 2019, The TensorFlow Authors.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -25,8 +25,11 @@ from privacy.bolton.optimizers import Bolton
class BoltonModel(Model): class BoltonModel(Model):
""" """Bolton episilon-delta differential privacy model.
Bolton episilon-delta model
The privacy guarantees are dependent on the noise that is sampled. Please
see the paper linked below for more details.
Uses 4 key steps to achieve privacy guarantees: Uses 4 key steps to achieve privacy guarantees:
1. Adds noise to weights after training (output perturbation). 1. Adds noise to weights after training (output perturbation).
2. Projects weights to R after each batch 2. Projects weights to R after each batch
@ -121,8 +124,9 @@ class BoltonModel(Model):
noise_distribution='laplace', noise_distribution='laplace',
steps_per_epoch=None, steps_per_epoch=None,
**kwargs): # pylint: disable=arguments-differ **kwargs): # pylint: disable=arguments-differ
"""Reroutes to super fit with additional Bolton delta-epsilon privacy """Reroutes to super fit with Bolton delta-epsilon privacy requirements.
requirements implemented. Note, inputs must be normalized s.t. ||x|| < 1
Note, inputs must be normalized s.t. ||x|| < 1.
Requirements are as follows: Requirements are as follows:
1. Adds noise to weights after training (output perturbation). 1. Adds noise to weights after training (output perturbation).
2. Projects weights to R after each batch 2. Projects weights to R after each batch
@ -139,7 +143,6 @@ class BoltonModel(Model):
whose dim == n_classes. whose dim == n_classes.
See the super method for descriptions on the rest of the arguments. See the super method for descriptions on the rest of the arguments.
""" """
if class_weight is None: if class_weight is None:
class_weight_ = self.calculate_class_weights(class_weight) class_weight_ = self.calculate_class_weights(class_weight)
@ -237,8 +240,8 @@ class BoltonModel(Model):
class_counts=None, class_counts=None,
num_classes=None num_classes=None
): ):
""" """Calculates class weighting to be used in training.
Calculates class weighting to be used in training. Can be on
Args: Args:
class_weights: str specifying type, array giving weights, or None. class_weights: str specifying type, array giving weights, or None.
class_counts: If class_weights is not None, then an array of class_counts: If class_weights is not None, then an array of
@ -246,7 +249,6 @@ class BoltonModel(Model):
num_classes: If class_weights is not None, then the number of num_classes: If class_weights is not None, then the number of
classes. classes.
Returns: class_weights as 1D tensor, to be passed to model's fit method. Returns: class_weights as 1D tensor, to be passed to model's fit method.
""" """
# Value checking # Value checking
class_keys = ['balanced'] class_keys = ['balanced']

View file

@ -1,4 +1,4 @@
# Copyright 2018, The TensorFlow Authors. # Copyright 2019, The TensorFlow Authors.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -38,40 +38,36 @@ class TestLoss(losses.Loss, StrongConvexMixin):
self.radius_constant = radius_constant self.radius_constant = radius_constant
def radius(self): def radius(self):
"""Radius of R-Ball (value to normalize weights to after each batch) """Radius, R, of the hypothesis space W.
W is a convex set that forms the hypothesis space.
Returns: radius Returns: radius
""" """
return _ops.convert_to_tensor_v2(1, dtype=tf.float32) return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def gamma(self): def gamma(self):
""" Gamma strongly convex """Returns strongly convex parameter, gamma."""
Returns: gamma
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32) return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def beta(self, class_weight): # pylint: disable=unused-argument def beta(self, class_weight): # pylint: disable=unused-argument
"""Beta smoothess """Smoothness, beta.
Args: Args:
class_weight: the class weights used. class_weight: the class weights as scalar or 1d tensor, where its
dimensionality is equal to the number of outputs.
Returns: Beta
Returns:
Beta
""" """
return _ops.convert_to_tensor_v2(1, dtype=tf.float32) return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def lipchitz_constant(self, class_weight): # pylint: disable=unused-argument def lipchitz_constant(self, class_weight): # pylint: disable=unused-argument
""" L lipchitz continuous """Lipchitz constant, L.
Args: Args:
class_weight: class weights used class_weight: class weights used
Returns: L Returns: L
""" """
return _ops.convert_to_tensor_v2(1, dtype=tf.float32) return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
@ -83,11 +79,25 @@ class TestLoss(losses.Loss, StrongConvexMixin):
) )
def max_class_weight(self, class_weight): def max_class_weight(self, class_weight):
"""the maximum weighting in class weights (max value) as a scalar tensor
Args:
class_weight: class weights used
dtype: the data type for tensor conversions.
Returns:
maximum class weighting as tensor scalar
"""
if class_weight is None: if class_weight is None:
return 1 return 1
raise ValueError('') raise ValueError('')
def kernel_regularizer(self): def kernel_regularizer(self):
"""Returns the kernel_regularizer to be used.
Any subclass should override this method if they want a kernel_regularizer
(if required for the loss function to be StronglyConvex.
"""
return L1L2(l2=self.reg_lambda) return L1L2(l2=self.reg_lambda)
@ -113,7 +123,7 @@ class TestOptimizer(OptimizerV2):
class InitTests(keras_parameterized.TestCase): class InitTests(keras_parameterized.TestCase):
"""tests for keras model initialization""" """Tests for keras model initialization."""
@parameterized.named_parameters([ @parameterized.named_parameters([
{'testcase_name': 'normal', {'testcase_name': 'normal',
@ -124,7 +134,7 @@ class InitTests(keras_parameterized.TestCase):
}, },
]) ])
def test_init_params(self, n_outputs): def test_init_params(self, n_outputs):
"""test initialization of BoltonModel """Test initialization of BoltonModel.
Args: Args:
n_outputs: number of output neurons n_outputs: number of output neurons
@ -243,8 +253,7 @@ def _do_fit(n_samples,
optimizer, optimizer,
loss, loss,
distribution='laplace'): distribution='laplace'):
"""Helper to instantiate necessary components for fitting and perform a model """Instantiate necessary components for fitting and perform a model fit.
fit.
Args: Args:
n_samples: number of samples in dataset n_samples: number of samples in dataset

View file

@ -1,4 +1,4 @@
# Copyright 2018, The TensorFlow Authors. # Copyright 2019, The TensorFlow Authors.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -26,11 +26,9 @@ _accepted_distributions = ['laplace'] # implemented distributions for noising
class GammaBetaDecreasingStep( class GammaBetaDecreasingStep(
optimizer_v2.learning_rate_schedule.LearningRateSchedule optimizer_v2.learning_rate_schedule.LearningRateSchedule):
): """Computes LR as minimum of 1/beta and 1/(gamma * step) at each step.
""" A required step for privacy guarantees.
Learning Rate Scheduler using the minimum of 1/beta and 1/(gamma * step)
at each step. A required step for privacy guarantees.
""" """
def __init__(self): def __init__(self):
self.is_init = False self.is_init = False
@ -38,8 +36,7 @@ class GammaBetaDecreasingStep(
self.gamma = None self.gamma = None
def __call__(self, step): def __call__(self, step):
""" """Computes and returns the learning rate.
returns the learning rate
Args: Args:
step: the current iteration number step: the current iteration number
Returns: Returns:
@ -61,15 +58,14 @@ class GammaBetaDecreasingStep(
) )
def get_config(self): def get_config(self):
""" """Return config to setup the learning rate scheduler."""
config to setup the learning rate scheduler.
"""
return {'beta': self.beta, 'gamma': self.gamma} return {'beta': self.beta, 'gamma': self.gamma}
def initialize(self, beta, gamma): def initialize(self, beta, gamma):
"""setup the learning rate scheduler with the beta and gamma values provided """Setups scheduler with beta and gamma values from the loss function.
by the loss function. Meant to be used with .fit as the loss params may
depend on values passed to fit. Meant to be used with .fit as the loss params may depend on values passed to
fit.
Args: Args:
beta: Smoothness value. See StrongConvexMixin beta: Smoothness value. See StrongConvexMixin
@ -80,37 +76,36 @@ class GammaBetaDecreasingStep(
self.gamma = gamma self.gamma = gamma
def de_initialize(self): def de_initialize(self):
"""De initialize the scheduler after fitting, in case another fit call has """De initialize post fit, as another fit call may use other parameters."""
different loss parameters.
"""
self.is_init = False self.is_init = False
self.beta = None self.beta = None
self.gamma = None self.gamma = None
class Bolton(optimizer_v2.OptimizerV2): class Bolton(optimizer_v2.OptimizerV2):
""" """Wrap another tf optimizer with Bolton privacy protocol.
Bolton optimizer wraps another tf optimizer to be used
as the visible optimizer to the tf model. No matter the optimizer
passed, "Bolton" enables the bolton model to control the learning rate
based on the strongly convex loss.
To use the Bolton method, you must: Bolton optimizer wraps another tf optimizer to be used
1. instantiate it with an instantiated tf optimizer and StrongConvexLoss. as the visible optimizer to the tf model. No matter the optimizer
2. use it as a context manager around your .fit method internals. passed, "Bolton" enables the bolton model to control the learning rate
based on the strongly convex loss.
This can be accomplished by the following: To use the Bolton method, you must:
optimizer = tf.optimizers.SGD() 1. instantiate it with an instantiated tf optimizer and StrongConvexLoss.
loss = privacy.bolton.losses.StrongConvexBinaryCrossentropy() 2. use it as a context manager around your .fit method internals.
bolton = Bolton(optimizer, loss)
with bolton(*args) as _:
model.fit()
The args required for the context manager can be found in the __call__
method.
For more details on the strong convexity requirements, see: This can be accomplished by the following:
Bolt-on Differential Privacy for Scalable Stochastic Gradient optimizer = tf.optimizers.SGD()
Descent-based Analytics by Xi Wu et. al. loss = privacy.bolton.losses.StrongConvexBinaryCrossentropy()
bolton = Bolton(optimizer, loss)
with bolton(*args) as _:
model.fit()
The args required for the context manager can be found in the __call__
method.
For more details on the strong convexity requirements, see:
Bolt-on Differential Privacy for Scalable Stochastic Gradient
Descent-based Analytics by Xi Wu et. al.
""" """
def __init__(self, # pylint: disable=super-init-not-called def __init__(self, # pylint: disable=super-init-not-called
optimizer: optimizer_v2.OptimizerV2, optimizer: optimizer_v2.OptimizerV2,
@ -120,9 +115,9 @@ class Bolton(optimizer_v2.OptimizerV2):
"""Constructor. """Constructor.
Args: Args:
optimizer: Optimizer_v2 or subclass to be used as the optimizer optimizer: Optimizer_v2 or subclass to be used as the optimizer
(wrapped). (wrapped).
loss: StrongConvexLoss function that the model is being compiled with. loss: StrongConvexLoss function that the model is being compiled with.
""" """
if not isinstance(loss, StrongConvexMixin): if not isinstance(loss, StrongConvexMixin):
@ -150,19 +145,15 @@ class Bolton(optimizer_v2.OptimizerV2):
self._is_init = False self._is_init = False
def get_config(self): def get_config(self):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer."""
"""
return self._internal_optimizer.get_config() return self._internal_optimizer.get_config()
def project_weights_to_r(self, force=False): def project_weights_to_r(self, force=False):
"""helper method to normalize the weights to the R-ball. """Normalize the weights to the R-ball.
Args: Args:
force: True to normalize regardless of previous weight values. force: True to normalize regardless of previous weight values.
False to check if weights > R-ball and only normalize then. False to check if weights > R-ball and only normalize then.
Returns:
""" """
if not self._is_init: if not self._is_init:
raise Exception('This method must be called from within the optimizer\'s ' raise Exception('This method must be called from within the optimizer\'s '
@ -186,8 +177,8 @@ class Bolton(optimizer_v2.OptimizerV2):
input_dim: the input dimensionality for the weights input_dim: the input dimensionality for the weights
output_dim the output dimensionality for the weights output_dim the output dimensionality for the weights
Returns: noise in shape of layer's weights to be added to the weights. Returns:
Noise in shape of layer's weights to be added to the weights.
""" """
if not self._is_init: if not self._is_init:
raise Exception('This method must be called from within the optimizer\'s ' raise Exception('This method must be called from within the optimizer\'s '
@ -221,8 +212,7 @@ class Bolton(optimizer_v2.OptimizerV2):
'a valid distribution'.format(distribution)) 'a valid distribution'.format(distribution))
def from_config(self, *args, **kwargs): # pylint: disable=arguments-differ def from_config(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer."""
"""
return self._internal_optimizer.from_config(*args, **kwargs) return self._internal_optimizer.from_config(*args, **kwargs)
def __getattr__(self, name): def __getattr__(self, name):
@ -230,11 +220,10 @@ class Bolton(optimizer_v2.OptimizerV2):
from the _internal_optimizer instance. from the _internal_optimizer instance.
Args: Args:
name: name:
Returns: attribute from Bolton if specified to come from self, else Returns: attribute from Bolton if specified to come from self, else
from _internal_optimizer. from _internal_optimizer.
""" """
if name == '_private_attributes' or name in self._private_attributes: if name == '_private_attributes' or name in self._private_attributes:
return getattr(self, name) return getattr(self, name)
@ -255,11 +244,8 @@ class Bolton(optimizer_v2.OptimizerV2):
Reroute everything else to the _internal_optimizer. Reroute everything else to the _internal_optimizer.
Args: Args:
key: attribute name key: attribute name
value: attribute value value: attribute value
Returns:
""" """
if key == '_private_attributes': if key == '_private_attributes':
object.__setattr__(self, key, value) object.__setattr__(self, key, value)
@ -269,44 +255,37 @@ class Bolton(optimizer_v2.OptimizerV2):
setattr(self._internal_optimizer, key, value) setattr(self._internal_optimizer, key, value)
def _resource_apply_dense(self, *args, **kwargs): # pylint: disable=arguments-differ def _resource_apply_dense(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer."""
"""
return self._internal_optimizer._resource_apply_dense(*args, **kwargs) # pylint: disable=protected-access return self._internal_optimizer._resource_apply_dense(*args, **kwargs) # pylint: disable=protected-access
def _resource_apply_sparse(self, *args, **kwargs): # pylint: disable=arguments-differ def _resource_apply_sparse(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer."""
"""
return self._internal_optimizer._resource_apply_sparse(*args, **kwargs) # pylint: disable=protected-access return self._internal_optimizer._resource_apply_sparse(*args, **kwargs) # pylint: disable=protected-access
def get_updates(self, loss, params): def get_updates(self, loss, params):
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer."""
"""
out = self._internal_optimizer.get_updates(loss, params) out = self._internal_optimizer.get_updates(loss, params)
self.project_weights_to_r() self.project_weights_to_r()
return out return out
def apply_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ def apply_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer."""
"""
out = self._internal_optimizer.apply_gradients(*args, **kwargs) out = self._internal_optimizer.apply_gradients(*args, **kwargs)
self.project_weights_to_r() self.project_weights_to_r()
return out return out
def minimize(self, *args, **kwargs): # pylint: disable=arguments-differ def minimize(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer."""
"""
out = self._internal_optimizer.minimize(*args, **kwargs) out = self._internal_optimizer.minimize(*args, **kwargs)
self.project_weights_to_r() self.project_weights_to_r()
return out return out
def _compute_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ,protected-access def _compute_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ,protected-access
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer."""
"""
return self._internal_optimizer._compute_gradients(*args, **kwargs) # pylint: disable=protected-access return self._internal_optimizer._compute_gradients(*args, **kwargs) # pylint: disable=protected-access
def get_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ def get_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ
"""Reroutes to _internal_optimizer. See super/_internal_optimizer. """Reroutes to _internal_optimizer. See super/_internal_optimizer."""
"""
return self._internal_optimizer.get_gradients(*args, **kwargs) return self._internal_optimizer.get_gradients(*args, **kwargs)
def __enter__(self): def __enter__(self):
@ -326,8 +305,8 @@ class Bolton(optimizer_v2.OptimizerV2):
n_samples, n_samples,
batch_size batch_size
): ):
"""Entry point from context. Accepts required values for bolton method and """Accepts required values for bolton method from context entry point.
stores them on the optimizer for use throughout fitting. Stores them on the optimizer for use throughout fitting.
Args: Args:
noise_distribution: the noise distribution to pick. noise_distribution: the noise distribution to pick.
@ -360,17 +339,15 @@ class Bolton(optimizer_v2.OptimizerV2):
def __exit__(self, *args): def __exit__(self, *args):
"""Exit call from with statement. """Exit call from with statement.
used to used to
1.reset the model and fit parameters passed to the optimizer
to enable the Bolton Privacy guarantees. These are reset to ensure
that any future calls to fit with the same instance of the optimizer
will properly error out.
2.call post-fit methods normalizing/projecting the model weights and
adding noise to the weights.
1.reset the model and fit parameters passed to the optimizer
to enable the Bolton Privacy guarantees. These are reset to ensure
that any future calls to fit with the same instance of the optimizer
will properly error out.
2.call post-fit methods normalizing/projecting the model weights and
adding noise to the weights.
""" """
self.project_weights_to_r(True) self.project_weights_to_r(True)
for layer in self.layers: for layer in self.layers:

View file

@ -1,4 +1,4 @@
# Copyright 2018, The TensorFlow Authors. # Copyright 2019, The TensorFlow Authors.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -33,8 +33,7 @@ from privacy.bolton import optimizers as opt
class TestModel(Model): class TestModel(Model):
""" """Bolton episilon-delta model.
Bolton episilon-delta model
Uses 4 key steps to achieve privacy guarantees: Uses 4 key steps to achieve privacy guarantees:
1. Adds noise to weights after training (output perturbation). 1. Adds noise to weights after training (output perturbation).
2. Projects weights to R after each batch 2. Projects weights to R after each batch
@ -47,14 +46,15 @@ class TestModel(Model):
""" """
def __init__(self, n_outputs=2, input_shape=(16,), init_value=2): def __init__(self, n_outputs=2, input_shape=(16,), init_value=2):
""" """Constructor.
Args: Args:
n_outputs: number of output neurons n_outputs: number of output neurons
epsilon: level of privacy guarantee epsilon: level of privacy guarantee
noise_distribution: distribution to pull weight perturbations from noise_distribution: distribution to pull weight perturbations from
weights_initializer: initializer for weights weights_initializer: initializer for weights
seed: random seed to use seed: random seed to use
dtype: data type to use for tensors dtype: data type to use for tensors
""" """
super(TestModel, self).__init__(name='bolton', dynamic=False) super(TestModel, self).__init__(name='bolton', dynamic=False)
self.n_outputs = n_outputs self.n_outputs = n_outputs
@ -76,40 +76,36 @@ class TestLoss(losses.Loss, StrongConvexMixin):
self.radius_constant = radius_constant self.radius_constant = radius_constant
def radius(self): def radius(self):
"""Radius of R-Ball (value to normalize weights to after each batch) """Radius, R, of the hypothesis space W.
W is a convex set that forms the hypothesis space.
Returns: radius Returns: radius
""" """
return _ops.convert_to_tensor_v2(self.radius_constant, dtype=tf.float32) return _ops.convert_to_tensor_v2(self.radius_constant, dtype=tf.float32)
def gamma(self): def gamma(self):
""" Gamma strongly convex """Returns strongly convex parameter, gamma."""
Returns: gamma
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32) return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def beta(self, class_weight): # pylint: disable=unused-argument def beta(self, class_weight): # pylint: disable=unused-argument
"""Beta smoothess """Smoothness, beta.
Args: Args:
class_weight: the class weights used. class_weight: the class weights as scalar or 1d tensor, where its
dimensionality is equal to the number of outputs.
Returns: Beta
Returns:
Beta
""" """
return _ops.convert_to_tensor_v2(1, dtype=tf.float32) return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
def lipchitz_constant(self, class_weight): # pylint: disable=unused-argument def lipchitz_constant(self, class_weight): # pylint: disable=unused-argument
""" L lipchitz continuous """Lipchitz constant, L.
Args: Args:
class_weight: class weights used class_weight: class weights used
Returns: L Returns: L
""" """
return _ops.convert_to_tensor_v2(1, dtype=tf.float32) return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
@ -121,11 +117,25 @@ class TestLoss(losses.Loss, StrongConvexMixin):
) )
def max_class_weight(self, class_weight, dtype=tf.float32): def max_class_weight(self, class_weight, dtype=tf.float32):
"""the maximum weighting in class weights (max value) as a scalar tensor
Args:
class_weight: class weights used
dtype: the data type for tensor conversions.
Returns:
maximum class weighting as tensor scalar
"""
if class_weight is None: if class_weight is None:
return 1 return 1
raise NotImplementedError('') raise NotImplementedError('')
def kernel_regularizer(self): def kernel_regularizer(self):
"""Returns the kernel_regularizer to be used.
Any subclass should override this method if they want a kernel_regularizer
(if required for the loss function to be StronglyConvex.
"""
return L1L2(l2=self.reg_lambda) return L1L2(l2=self.reg_lambda)

View file

@ -1,432 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"pycharm": {
"is_executing": false
},
"scrolled": true
},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('..')\n",
"import tensorflow as tf\n",
"from privacy.bolton import losses\n",
"from privacy.bolton import models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we will create a binary classification dataset with a single output dimension.\n",
"The samples for each label are repeated datapoints at different points in space."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"pycharm": {
"is_executing": false,
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(20, 2) (20, 1)\n"
]
}
],
"source": [
"# Parameters for dataset\n",
"n_samples = 10\n",
"input_dim = 2\n",
"n_outputs = 1\n",
"# Create binary classification dataset:\n",
"x_stack = [tf.constant(-1, tf.float32, (n_samples, input_dim)), \n",
" tf.constant(1, tf.float32, (n_samples, input_dim))]\n",
"y_stack = [tf.constant(0, tf.float32, (n_samples, 1)),\n",
" tf.constant(1, tf.float32, (n_samples, 1))]\n",
"x, y = tf.concat(x_stack, 0), tf.concat(y_stack, 0)\n",
"print(x.shape, y.shape)\n",
"generator = tf.data.Dataset.from_tensor_slices((x, y))\n",
"generator = generator.batch(10)\n",
"generator = generator.shuffle(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we will explore using the pre-built BoltonModel, which is a thin wrapper around a Keras Model using a single-layer neural network. It automatically uses the Bolton Optimizer which encompasses all the logic required for the Bolton Differential Privacy method.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [],
"source": [
"bolt = models.BoltonModel(n_outputs) # tell the model how many outputs we have."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we will pick our optimizer and Strongly Convex Loss function. The loss must extend from StrongConvexMixin and implement the associated methods. Some existing loss functions are pre-implemented in bolton.loss"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [],
"source": [
"optimizer = tf.optimizers.SGD()\n",
"reg_lambda = 1\n",
"C = 1\n",
"radius_constant = 1\n",
"loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For simplicity, we pick all parameters of the StrongConvexBinaryCrossentropy to be 1; these are all tunable and their impact can be read in losses.StrongConvexBinaryCrossentropy. We then compile the model with the chosen optimizer and loss, which will automatically wrap the chosen optimizer with the Bolton Optimizer, ensuring the required components function as required for privacy guarantees."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"pycharm": {
"is_executing": false
}
},
"outputs": [],
"source": [
"bolt.compile(optimizer, loss)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To fit the model, the optimizer will require additional information about the dataset and model. These parameters are:\n",
"1. the class_weights used\n",
"2. the number of samples in the dataset\n",
"3. the batch size\n",
"which the model will try to infer, if possible. If not, you will be required to pass these explicitly to the fit method.\n",
"As well, there are two privacy parameters than can be altered: \n",
"1. epsilon, a float\n",
"2. noise_distribution, a valid string indicating the distriution to use (must be implemented)\n",
"\n",
"The BoltonModel offers a helper method, .calculate_class_weight to aid in class_weight calculation."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING: Logging before flag parsing goes to stderr.\n",
"W0619 11:00:32.392859 4467058112 deprecation.py:323] From /Users/christopherchoo/PycharmProjects/privacy/venv/lib/python3.6/site-packages/tensorflow/python/ops/nn_impl.py:182: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use tf.where in 2.0, which has the same broadcast rule as np.where\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 20 samples\n",
"Epoch 1/2\n",
"20/20 [==============================] - 0s 4ms/sample - loss: 0.8146\n",
"Epoch 2/2\n",
"20/20 [==============================] - 0s 94us/sample - loss: 0.5699\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x10543d0f0>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# required parameters\n",
"class_weight = None # default, use .calculate_class_weight to specify other values\n",
"batch_size = None # default, if it cannot be inferred, specify this\n",
"n_samples = None # default, if it cannot be iferred, specify this\n",
"# privacy parameters\n",
"epsilon = 2\n",
"noise_distribution = 'laplace'\n",
"\n",
"bolt.fit(x, \n",
" y, \n",
" epsilon=epsilon, \n",
" class_weight=class_weight, \n",
" batch_size=batch_size, \n",
" n_samples=n_samples,\n",
" noise_distribution=noise_distribution,\n",
" epochs=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We may also train a generator object, or try different optimizers and loss functions. Below, we will see that we must pass the number of samples as the fit method is unable to infer it for a generator."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"optimizer2 = tf.optimizers.Adam()\n",
"bolt.compile(optimizer2, loss)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Could not infer the number of samples. Please pass this in using n_samples.\n"
]
}
],
"source": [
"# required parameters\n",
"class_weight = None # default, use .calculate_class_weight to specify other values\n",
"batch_size = None # default, if it cannot be inferred, specify this\n",
"n_samples = None # default, if it cannot be iferred, specify this\n",
"# privacy parameters\n",
"epsilon = 2\n",
"noise_distribution = 'laplace'\n",
"try:\n",
" bolt.fit(generator,\n",
" epsilon=epsilon, \n",
" class_weight=class_weight, \n",
" batch_size=batch_size, \n",
" n_samples=n_samples,\n",
" noise_distribution=noise_distribution,\n",
" verbose=0\n",
" )\n",
"except ValueError as e:\n",
" print(e)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now, re running with the parameter set."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x1267db4a8>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n_samples = 20\n",
"bolt.fit(generator,\n",
" epsilon=epsilon, \n",
" class_weight=class_weight, \n",
" batch_size=batch_size, \n",
" n_samples=n_samples,\n",
" noise_distribution=noise_distribution,\n",
" verbose=0\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You don't have to use the bolton model to use the Bolton method. There are only a few requirements:\n",
"1. make sure any requirements from the loss are implemented in the model.\n",
"2. instantiate the optimizer and use it as a context around your fit operation."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from privacy.bolton.optimizers import Bolton"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here, we create our own model and setup the Bolton optimizer."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"class TestModel(tf.keras.Model):\n",
" def __init__(self, reg_layer, n_outputs=1):\n",
" super(TestModel, self).__init__(name='test')\n",
" self.output_layer = tf.keras.layers.Dense(n_outputs,\n",
" kernel_regularizer=reg_layer\n",
" )\n",
" \n",
" def call(self, inputs):\n",
" return self.output_layer(inputs)\n",
"\n",
"optimizer = tf.optimizers.SGD()\n",
"loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)\n",
"optimizer = Bolton(optimizer, loss)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we instantiate our model and check for 1. Since our loss requires L2 regularization over the kernel, we will pass it to the model."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"n_outputs = 1 # parameter for model and optimizer context.\n",
"test_model = TestModel(loss.kernel_regularizer(), n_outputs)\n",
"test_model.compile(optimizer, loss)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We comply with 2., and use the Bolton Optimizer as a context around the fit method."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 20 samples\n",
"Epoch 1/2\n",
"20/20 [==============================] - 0s 3ms/sample - loss: 0.9096\n",
"Epoch 2/2\n",
"20/20 [==============================] - 0s 430us/sample - loss: 0.5275\n"
]
}
],
"source": [
"# parameters for context\n",
"noise_distribution = 'laplace'\n",
"epsilon = 2\n",
"class_weights = 1 # Previosuly, the fit method auto-detected the class_weights.\n",
"# Here, we need to pass the class_weights explicitly. 1 is the equivalent of None.\n",
"n_samples = 20\n",
"batch_size = 5\n",
"\n",
"with optimizer(\n",
" noise_distribution=noise_distribution,\n",
" epsilon=epsilon,\n",
" layers=test_model.layers,\n",
" class_weights=class_weights, \n",
" n_samples=n_samples,\n",
" batch_size=batch_size\n",
") as _:\n",
" test_model.fit(x, y, batch_size=batch_size, epochs=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"source": [],
"metadata": {
"collapsed": false
}
}
}
},
"nbformat": 4,
"nbformat_minor": 1
}

View file

@ -0,0 +1,155 @@
import sys
sys.path.append('..')
import tensorflow as tf
from privacy.bolton import losses
from privacy.bolton import models
"""First, we will create a binary classification dataset with a single output
dimension. The samples for each label are repeated data points at different
points in space."""
# Parameters for dataset
n_samples = 10
input_dim = 2
n_outputs = 1
# Create binary classification dataset:
x_stack = [tf.constant(-1, tf.float32, (n_samples, input_dim)),
tf.constant(1, tf.float32, (n_samples, input_dim))]
y_stack = [tf.constant(0, tf.float32, (n_samples, 1)),
tf.constant(1, tf.float32, (n_samples, 1))]
x, y = tf.concat(x_stack, 0), tf.concat(y_stack, 0)
print(x.shape, y.shape)
generator = tf.data.Dataset.from_tensor_slices((x, y))
generator = generator.batch(10)
generator = generator.shuffle(10)
"""First, we will explore using the pre - built BoltonModel, which is a thin
wrapper around a Keras Model using a single - layer neural network.
It automatically uses the Bolton Optimizer which encompasses all the logic
required for the Bolton Differential Privacy method."""
bolt = models.BoltonModel(n_outputs) # tell the model how many outputs we have.
"""Now, we will pick our optimizer and Strongly Convex Loss function. The loss
must extend from StrongConvexMixin and implement the associated methods.Some
existing loss functions are pre - implemented in bolton.loss"""
optimizer = tf.optimizers.SGD()
reg_lambda = 1
C = 1
radius_constant = 1
loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
"""For simplicity, we pick all parameters of the StrongConvexBinaryCrossentropy
to be 1; these are all tunable and their impact can be read in losses.
StrongConvexBinaryCrossentropy.We then compile the model with the chosen
optimizer and loss, which will automatically wrap the chosen optimizer with the
Bolton Optimizer, ensuring the required components function as required for
privacy guarantees."""
bolt.compile(optimizer, loss)
"""To fit the model, the optimizer will require additional information about
the dataset and model.These parameters are:
1. the class_weights used
2. the number of samples in the dataset
3. the batch size which the model will try to infer, if possible. If not, you
will be required to pass these explicitly to the fit method.
As well, there are two privacy parameters than can be altered:
1. epsilon, a float
2. noise_distribution, a valid string indicating the distriution to use (must be
implemented)
The BoltonModel offers a helper method,.calculate_class_weight to aid in
class_weight calculation."""
# required parameters
class_weight = None # default, use .calculate_class_weight to specify other values
batch_size = None # default, if it cannot be inferred, specify this
n_samples = None # default, if it cannot be iferred, specify this
# privacy parameters
epsilon = 2
noise_distribution = 'laplace'
bolt.fit(x,
y,
epsilon=epsilon,
class_weight=class_weight,
batch_size=batch_size,
n_samples=n_samples,
noise_distribution=noise_distribution,
epochs=2)
"""We may also train a generator object, or try different optimizers and loss
functions. Below, we will see that we must pass the number of samples as the fit
method is unable to infer it for a generator."""
optimizer2 = tf.optimizers.Adam()
bolt.compile(optimizer2, loss)
# required parameters
class_weight = None # default, use .calculate_class_weight to specify other values
batch_size = None # default, if it cannot be inferred, specify this
n_samples = None # default, if it cannot be iferred, specify this
# privacy parameters
epsilon = 2
noise_distribution = 'laplace'
try:
bolt.fit(generator,
epsilon=epsilon,
class_weight=class_weight,
batch_size=batch_size,
n_samples=n_samples,
noise_distribution=noise_distribution,
verbose=0
)
except ValueError as e:
print(e)
"""And now, re running with the parameter set."""
n_samples = 20
bolt.fit(generator,
epsilon=epsilon,
class_weight=class_weight,
batch_size=batch_size,
n_samples=n_samples,
noise_distribution=noise_distribution,
verbose=0
)
"""You don't have to use the bolton model to use the Bolton method.
There are only a few requirements:
1. make sure any requirements from the loss are implemented in the model.
2. instantiate the optimizer and use it as a context around your fit operation.
"""
from privacy.bolton.optimizers import Bolton
"""Here, we create our own model and setup the Bolton optimizer."""
class TestModel(tf.keras.Model):
def __init__(self, reg_layer, n_outputs=1):
super(TestModel, self).__init__(name='test')
self.output_layer = tf.keras.layers.Dense(n_outputs,
kernel_regularizer=reg_layer
)
def call(self, inputs):
return self.output_layer(inputs)
optimizer = tf.optimizers.SGD()
loss = losses.StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant)
optimizer = Bolton(optimizer, loss)
"""Now, we instantiate our model and check for 1. Since our loss requires L2
regularization over the kernel, we will pass it to the model."""
n_outputs = 1 # parameter for model and optimizer context.
test_model = TestModel(loss.kernel_regularizer(), n_outputs)
test_model.compile(optimizer, loss)
"""We comply with 2., and use the Bolton Optimizer as a context around the fit
method."""
# parameters for context
noise_distribution = 'laplace'
epsilon = 2
class_weights = 1 # Previously, the fit method auto-detected the class_weights.
# Here, we need to pass the class_weights explicitly. 1 is the equivalent of None.
n_samples = 20
batch_size = 5
with optimizer(
noise_distribution=noise_distribution,
epsilon=epsilon,
layers=test_model.layers,
class_weights=class_weights,
n_samples=n_samples,
batch_size=batch_size
) as _:
test_model.fit(x, y, batch_size=batch_size, epochs=2)