From 751eaead545d45bcc47bff7d82656b08c474b434 Mon Sep 17 00:00:00 2001 From: Christopher Choquette Choo Date: Mon, 10 Jun 2019 16:11:47 -0400 Subject: [PATCH] Working bolton model without unit tests. -- update to include pull request changes changes include: parameter renaming, changing to mixin, moving model to compile, additional tests, fixing huber loss --- privacy/bolton/__init__.py | 4 +- privacy/bolton/loss.py | 502 +++++++++++++++++++++++++------ privacy/bolton/loss_test.py | 324 +++++++++++++++++++- privacy/bolton/model.py | 89 ++++-- privacy/bolton/model_test.py | 493 +++++++++++++++++++++++++++++- privacy/bolton/optimizer.py | 56 ++-- privacy/bolton/optimizer_test.py | 173 +++++++++++ 7 files changed, 1472 insertions(+), 169 deletions(-) diff --git a/privacy/bolton/__init__.py b/privacy/bolton/__init__.py index 46bd079..67b6148 100644 --- a/privacy/bolton/__init__.py +++ b/privacy/bolton/__init__.py @@ -10,5 +10,5 @@ if hasattr(sys, 'skip_tf_privacy_import'): # Useful for standalone scripts. pass else: from privacy.bolton.model import Bolton - from privacy.bolton.loss import Huber - from privacy.bolton.loss import BinaryCrossentropy \ No newline at end of file + from privacy.bolton.loss import StrongConvexHuber + from privacy.bolton.loss import StrongConvexBinaryCrossentropy \ No newline at end of file diff --git a/privacy/bolton/loss.py b/privacy/bolton/loss.py index dd5d580..5cc029a 100644 --- a/privacy/bolton/loss.py +++ b/privacy/bolton/loss.py @@ -20,56 +20,33 @@ import tensorflow as tf from tensorflow.python.keras import losses from tensorflow.python.keras.utils import losses_utils from tensorflow.python.framework import ops as _ops +from tensorflow.python.keras.regularizers import L1L2 -class StrongConvexLoss(losses.Loss): +class StrongConvexMixin: """ - Strong Convex Loss base class for any loss function that will be used with + Strong Convex Mixin base class for any loss function that will be used with Bolton model. Subclasses must be strongly convex and implement the associated constants. They must also conform to the requirements of tf losses - (see super class) + (see super class). + + For more details on the strong convexity requirements, see: + Bolt-on Differential Privacy for Scalable Stochastic Gradient + Descent-based Analytics by Xi Wu et. al. """ - def __init__(self, - reg_lambda: float, - c: float, - radius_constant: float = 1, - reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, - name: str = None, - dtype=tf.float32, - **kwargs): - """ - Args: - reg_lambda: Weight regularization constant - c: Additional constant for strongly convex convergence. Acts - as a global weight. - radius_constant: constant defining the length of the radius - reduction: reduction type to use. See super class - name: Name of the loss instance - dtype: tf datatype to use for tensor conversions. - """ - super(StrongConvexLoss, self).__init__(reduction=reduction, - name=name, - **kwargs) - self._sample_weight = tf.Variable(initial_value=c, - trainable=False, - dtype=tf.float32) - self._reg_lambda = reg_lambda - self.radius_constant = tf.Variable(initial_value=radius_constant, - trainable=False, - dtype=tf.float32) - self.dtype = dtype def radius(self): - """Radius of R-Ball (value to normalize weights to after each batch) + """Radius, R, of the hypothesis space W. + W is a convex set that forms the hypothesis space. 
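To make the mixin contract concrete, a hypothetical subclass is sketched below. The base loss (MeanSquaredError) and the constant formulas are illustrative placeholders only; they are not part of this patch and are not derived for that loss.

# Hypothetical sketch (not in this patch): a loss supplying the StrongConvexMixin
# contract. Constant formulas are placeholders, not values derived for MSE.
import tensorflow as tf
from tensorflow.python.keras import losses
from tensorflow.python.keras.regularizers import L1L2
from privacy.bolton.loss import StrongConvexMixin


class ExampleStrongConvexLoss(losses.MeanSquaredError, StrongConvexMixin):
  def __init__(self, reg_lambda, radius_constant, dtype=tf.float32):
    super(ExampleStrongConvexLoss, self).__init__()
    self.dtype = dtype
    self.reg_lambda = tf.constant(reg_lambda, dtype=dtype)
    self.radius_constant = radius_constant

  def radius(self):  # R of the hypothesis space W
    return self.radius_constant / self.reg_lambda

  def gamma(self):  # strong convexity constant
    return self.reg_lambda

  def beta(self, class_weight):  # smoothness constant
    return self.max_class_weight(class_weight, self.dtype) + self.reg_lambda

  def lipchitz_constant(self, class_weight):  # Lipschitz constant
    return (self.max_class_weight(class_weight, self.dtype)
            + self.reg_lambda * self.radius())

  def kernel_regularizer(self):  # l2 term needed for strong convexity
    return L1L2(l2=self.reg_lambda)

Any such subclass can then be used with the Bolton model in place of the losses defined later in this file.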
- Returns: radius + Returns: R """ raise NotImplementedError("Radius not implemented for StrongConvex Loss" "function: %s" % str(self.__class__.__name__)) def gamma(self): - """ Gamma strongly convex + """ Strongly convexity, gamma Returns: gamma @@ -78,7 +55,7 @@ class StrongConvexLoss(losses.Loss): "function: %s" % str(self.__class__.__name__)) def beta(self, class_weight): - """Beta smoothess + """Smoothness, beta Args: class_weight: the class weights used. @@ -90,7 +67,7 @@ class StrongConvexLoss(losses.Loss): "function: %s" % str(self.__class__.__name__)) def lipchitz_constant(self, class_weight): - """ L lipchitz continuous + """Lipchitz constant, L Args: class_weight: class weights used @@ -102,43 +79,46 @@ class StrongConvexLoss(losses.Loss): "StrongConvex Loss" "function: %s" % str(self.__class__.__name__)) - def reg_lambda(self, convert_to_tensor: bool = False): - """ returns the lambda weight regularization constant, as a tensor if - desired + def kernel_regularizer(self): + """returns the kernel_regularizer to be used. Any subclass should override + this method if they want a kernel_regularizer (if required for + the loss function to be StronglyConvex + + :return: None or kernel_regularizer layer + """ + return None + + def max_class_weight(self, class_weight, dtype): + """the maximum weighting in class weights (max value) as a scalar tensor Args: - convert_to_tensor: True to convert to tensor, False to leave as - python numeric. + class_weight: class weights used + dtype: the data type for tensor conversions. - Returns: reg_lambda + Returns: maximum class weighting as tensor scalar """ - if convert_to_tensor: - return _ops.convert_to_tensor_v2(self._reg_lambda, dtype=self.dtype) - return self._reg_lambda - - def max_class_weight(self, class_weight): - class_weight = _ops.convert_to_tensor_v2(class_weight, dtype=self.dtype) + class_weight = _ops.convert_to_tensor_v2(class_weight, dtype) return tf.math.reduce_max(class_weight) -class Huber(StrongConvexLoss, losses.Huber): - """Strong Convex version of huber loss using l2 weight regularization. +class StrongConvexHuber(losses.Huber, StrongConvexMixin): + """Strong Convex version of Huber loss using l2 weight regularization. """ + def __init__(self, reg_lambda: float, - c: float, + C: float, radius_constant: float, delta: float, reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, name: str = 'huber', dtype=tf.float32): - """Constructor. Passes arguments to StrongConvexLoss and Huber Loss. + """Constructor. Args: reg_lambda: Weight regularization constant - c: Additional constant for strongly convex convergence. Acts - as a global weight. + C: Penalty parameter C of the loss term radius_constant: constant defining the length of the radius delta: delta value in huber loss. When to switch from quadratic to absolute deviation. @@ -149,15 +129,22 @@ class Huber(StrongConvexLoss, losses.Huber): Returns: Loss values per sample. 
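As a quick usage sketch (the numbers mirror the Huber test cases added in loss_test.py further down), evaluating this loss on a single example with delta=1 exercises the quadratic branch:

# Illustrative only; values mirror the expected results in loss_test.py.
import tensorflow as tf
from privacy.bolton.loss import StrongConvexHuber

huber = StrongConvexHuber(reg_lambda=0.00001, C=1, radius_constant=1, delta=1)
y_true = tf.constant(1, dtype=tf.float32)
logits = tf.constant(0.1, dtype=tf.float32)
loss_value = huber(y_true, logits)  # quadratic branch: 1/(4*delta) * (1 + delta - z)**2
print(loss_value.numpy())           # approx. 0.9025 == 1.9**2 * 0.25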
""" - # self.delta = tf.Variable(initial_value=delta, trainable=False) - super(Huber, self).__init__( - reg_lambda, - c, - radius_constant, + if C <= 0: + raise ValueError('c: {0}, should be >= 0'.format(C)) + if reg_lambda <= 0: + raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) + if radius_constant <= 0: + raise ValueError('radius_constant: {0}, should be >= 0'.format( + radius_constant + )) + self.C = C + self.radius_constant = radius_constant + self.dtype = dtype + self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) + super(StrongConvexHuber, self).__init__( delta=delta, name=name, reduction=reduction, - dtype=dtype ) def call(self, y_true, y_pred): @@ -170,46 +157,73 @@ class Huber(StrongConvexLoss, losses.Huber): Returns: Loss values per sample. """ - return super(Huber, self).call(y_true, y_pred, **self._fn_kwargs) * \ - self._sample_weight + # return super(StrongConvexHuber, self).call(y_true, y_pred) * self._sample_weight + h = self._fn_kwargs['delta'] + z = y_pred * y_true + one = tf.constant(1, dtype=self.dtype) + four = tf.constant(4, dtype=self.dtype) + + if z > one + h: + return z - z + elif tf.math.abs(one - z) <= h: + return one / (four * h) * tf.math.pow(one + h - z, 2) + elif z < one - h: + return one - z + else: + raise ValueError('') def radius(self): """See super class. """ - return self.radius_constant / self.reg_lambda(True) + return self.radius_constant / self.reg_lambda def gamma(self): """See super class. """ - return self.reg_lambda(True) + return self.reg_lambda def beta(self, class_weight): """See super class. """ - max_class_weight = self.max_class_weight(class_weight) - return self._sample_weight * max_class_weight / \ - (self.delta * tf.Variable(initial_value=2, trainable=False)) + \ - self.reg_lambda(True) + max_class_weight = self.max_class_weight(class_weight, self.dtype) + delta = _ops.convert_to_tensor_v2(self._fn_kwargs['delta'], + dtype=self.dtype + ) + return self.C * max_class_weight / (delta * + tf.constant(2, dtype=self.dtype)) + \ + self.reg_lambda def lipchitz_constant(self, class_weight): """See super class. """ # if class_weight is provided, # it should be a vector of the same size of number of classes - max_class_weight = self.max_class_weight(class_weight) - lc = self._sample_weight * max_class_weight + \ - self.reg_lambda(True) * self.radius() + max_class_weight = self.max_class_weight(class_weight, self.dtype) + lc = self.C * max_class_weight + \ + self.reg_lambda * self.radius() return lc + def kernel_regularizer(self): + """ + l2 loss using reg_lambda as the l2 term (as desired). Required for + this loss function to be strongly convex. + :return: + """ + return L1L2(l2=self.reg_lambda) -class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy): + +class StrongConvexBinaryCrossentropy( + losses.BinaryCrossentropy, + StrongConvexMixin +): """ Strong Convex version of BinaryCrossentropy loss using l2 weight regularization. """ + def __init__(self, reg_lambda: float, - c: float, + C: float, radius_constant: float, from_logits: bool = True, label_smoothing: float = 0, @@ -219,8 +233,7 @@ class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy): """ Args: reg_lambda: Weight regularization constant - c: Additional constant for strongly convex convergence. Acts - as a global weight. + C: Penalty parameter C of the loss term radius_constant: constant defining the length of the radius reduction: reduction type to use. 
See super class label_smoothing: amount of smoothing to perform on labels @@ -228,15 +241,23 @@ class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy): name: Name of the loss instance dtype: tf datatype to use for tensor conversions. """ - super(BinaryCrossentropy, self).__init__(reg_lambda, - c, - radius_constant, - reduction=reduction, - name=name, - from_logits=from_logits, - label_smoothing=label_smoothing, - dtype=dtype - ) + if reg_lambda <= 0: + raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) + if C <= 0: + raise ValueError('c: {0}, should be >= 0'.format(C)) + if radius_constant <= 0: + raise ValueError('radius_constant: {0}, should be >= 0'.format( + radius_constant + )) + self.dtype = dtype + self.C = C + self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) + super(StrongConvexBinaryCrossentropy, self).__init__( + reduction=reduction, + name=name, + from_logits=from_logits, + label_smoothing=label_smoothing, + ) self.radius_constant = radius_constant def call(self, y_true, y_pred): @@ -249,32 +270,319 @@ class BinaryCrossentropy(StrongConvexLoss, losses.BinaryCrossentropy): Returns: Loss values per sample. """ - loss = tf.nn.sigmoid_cross_entropy_with_logits( - labels=y_true, - logits=y_pred - ) - loss = loss * self._sample_weight + # loss = tf.nn.sigmoid_cross_entropy_with_logits( + # labels=y_true, + # logits=y_pred + # ) + loss = super(StrongConvexBinaryCrossentropy, self).call(y_true, y_pred) + loss = loss * self.C return loss def radius(self): """See super class. """ - return self.radius_constant / self.reg_lambda(True) + return self.radius_constant / self.reg_lambda def gamma(self): """See super class. """ - return self.reg_lambda(True) + return self.reg_lambda def beta(self, class_weight): """See super class. """ - max_class_weight = self.max_class_weight(class_weight) - return self._sample_weight * max_class_weight + self.reg_lambda(True) + max_class_weight = self.max_class_weight(class_weight, self.dtype) + return self.C * max_class_weight + self.reg_lambda def lipchitz_constant(self, class_weight): """See super class. """ - max_class_weight = self.max_class_weight(class_weight) - return self._sample_weight * max_class_weight + \ - self.reg_lambda(True) * self.radius() + max_class_weight = self.max_class_weight(class_weight, self.dtype) + return self.C * max_class_weight + self.reg_lambda * self.radius() + + def kernel_regularizer(self): + """ + l2 loss using reg_lambda as the l2 term (as desired). Required for + this loss function to be strongly convex. + :return: + """ + return L1L2(l2=self.reg_lambda) + + +# class StrongConvexSparseCategoricalCrossentropy( +# losses.CategoricalCrossentropy, +# StrongConvexMixin +# ): +# """ +# Strong Convex version of CategoricalCrossentropy loss using l2 weight +# regularization. +# """ +# +# def __init__(self, +# reg_lambda: float, +# C: float, +# radius_constant: float, +# from_logits: bool = True, +# label_smoothing: float = 0, +# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, +# name: str = 'binarycrossentropy', +# dtype=tf.float32): +# """ +# Args: +# reg_lambda: Weight regularization constant +# C: Penalty parameter C of the loss term +# radius_constant: constant defining the length of the radius +# reduction: reduction type to use. See super class +# label_smoothing: amount of smoothing to perform on labels +# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) +# name: Name of the loss instance +# dtype: tf datatype to use for tensor conversions. 
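To tie the StrongConvexBinaryCrossentropy constants above together, a small worked example follows; the values are the same ones asserted in the test_fns cases of loss_test.py below.

# Illustrative only; these are the values checked in loss_test.py.
import tensorflow as tf
from privacy.bolton.loss import StrongConvexBinaryCrossentropy

bce = StrongConvexBinaryCrossentropy(reg_lambda=1, C=1, radius_constant=1)
print(bce.radius().numpy())              # radius_constant / reg_lambda = 1.0
print(bce.gamma().numpy())               # reg_lambda = 1.0
print(bce.beta(1).numpy())               # C * max_class_weight + reg_lambda = 2.0
print(bce.lipchitz_constant(1).numpy())  # C * max_class_weight + reg_lambda * radius = 2.0
print(bce.kernel_regularizer().l2)       # l2 penalty equal to reg_lambda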
+# """ +# if reg_lambda <= 0: +# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) +# if C <= 0: +# raise ValueError('c: {0}, should be >= 0'.format(C)) +# if radius_constant <= 0: +# raise ValueError('radius_constant: {0}, should be >= 0'.format( +# radius_constant +# )) +# +# self.C = C +# self.dtype = dtype +# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) +# super(StrongConvexSparseCategoricalCrossentropy, self).__init__( +# reduction=reduction, +# name=name, +# from_logits=from_logits, +# label_smoothing=label_smoothing, +# ) +# self.radius_constant = radius_constant +# +# def call(self, y_true, y_pred): +# """Compute loss +# +# Args: +# y_true: Ground truth values. +# y_pred: The predicted values. +# +# Returns: +# Loss values per sample. +# """ +# loss = super() +# loss = loss * self.C +# return loss +# +# def radius(self): +# """See super class. +# """ +# return self.radius_constant / self.reg_lambda +# +# def gamma(self): +# """See super class. +# """ +# return self.reg_lambda +# +# def beta(self, class_weight): +# """See super class. +# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda +# +# def lipchitz_constant(self, class_weight): +# """See super class. +# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda * self.radius() +# +# def kernel_regularizer(self): +# """ +# l2 loss using reg_lambda as the l2 term (as desired). Required for +# this loss function to be strongly convex. +# :return: +# """ +# return L1L2(l2=self.reg_lambda) +# +# class StrongConvexSparseCategoricalCrossentropy( +# losses.SparseCategoricalCrossentropy, +# StrongConvexMixin +# ): +# """ +# Strong Convex version of SparseCategoricalCrossentropy loss using l2 weight +# regularization. +# """ +# +# def __init__(self, +# reg_lambda: float, +# C: float, +# radius_constant: float, +# from_logits: bool = True, +# label_smoothing: float = 0, +# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, +# name: str = 'binarycrossentropy', +# dtype=tf.float32): +# """ +# Args: +# reg_lambda: Weight regularization constant +# C: Penalty parameter C of the loss term +# radius_constant: constant defining the length of the radius +# reduction: reduction type to use. See super class +# label_smoothing: amount of smoothing to perform on labels +# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) +# name: Name of the loss instance +# dtype: tf datatype to use for tensor conversions. +# """ +# if reg_lambda <= 0: +# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) +# if C <= 0: +# raise ValueError('c: {0}, should be >= 0'.format(C)) +# if radius_constant <= 0: +# raise ValueError('radius_constant: {0}, should be >= 0'.format( +# radius_constant +# )) +# +# self.C = C +# self.dtype = dtype +# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) +# super(StrongConvexHuber, self).__init__(reduction=reduction, +# name=name, +# from_logits=from_logits, +# label_smoothing=label_smoothing, +# ) +# self.radius_constant = radius_constant +# +# def call(self, y_true, y_pred): +# """Compute loss +# +# Args: +# y_true: Ground truth values. +# y_pred: The predicted values. +# +# Returns: +# Loss values per sample. +# """ +# loss = super() +# loss = loss * self.C +# return loss +# +# def radius(self): +# """See super class. 
+# """ +# return self.radius_constant / self.reg_lambda +# +# def gamma(self): +# """See super class. +# """ +# return self.reg_lambda +# +# def beta(self, class_weight): +# """See super class. +# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda +# +# def lipchitz_constant(self, class_weight): +# """See super class. +# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda * self.radius() +# +# def kernel_regularizer(self): +# """ +# l2 loss using reg_lambda as the l2 term (as desired). Required for +# this loss function to be strongly convex. +# :return: +# """ +# return L1L2(l2=self.reg_lambda) +# +# +# class StrongConvexCategoricalCrossentropy( +# losses.CategoricalCrossentropy, +# StrongConvexMixin +# ): +# """ +# Strong Convex version of CategoricalCrossentropy loss using l2 weight +# regularization. +# """ +# +# def __init__(self, +# reg_lambda: float, +# C: float, +# radius_constant: float, +# from_logits: bool = True, +# label_smoothing: float = 0, +# reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, +# name: str = 'binarycrossentropy', +# dtype=tf.float32): +# """ +# Args: +# reg_lambda: Weight regularization constant +# C: Penalty parameter C of the loss term +# radius_constant: constant defining the length of the radius +# reduction: reduction type to use. See super class +# label_smoothing: amount of smoothing to perform on labels +# relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x) +# name: Name of the loss instance +# dtype: tf datatype to use for tensor conversions. +# """ +# if reg_lambda <= 0: +# raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) +# if C <= 0: +# raise ValueError('c: {0}, should be >= 0'.format(C)) +# if radius_constant <= 0: +# raise ValueError('radius_constant: {0}, should be >= 0'.format( +# radius_constant +# )) +# +# self.C = C +# self.dtype = dtype +# self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) +# super(StrongConvexHuber, self).__init__(reduction=reduction, +# name=name, +# from_logits=from_logits, +# label_smoothing=label_smoothing, +# ) +# self.radius_constant = radius_constant +# +# def call(self, y_true, y_pred): +# """Compute loss +# +# Args: +# y_true: Ground truth values. +# y_pred: The predicted values. +# +# Returns: +# Loss values per sample. +# """ +# loss = super() +# loss = loss * self.C +# return loss +# +# def radius(self): +# """See super class. +# """ +# return self.radius_constant / self.reg_lambda +# +# def gamma(self): +# """See super class. +# """ +# return self.reg_lambda +# +# def beta(self, class_weight): +# """See super class. +# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda +# +# def lipchitz_constant(self, class_weight): +# """See super class. +# """ +# max_class_weight = self.max_class_weight(class_weight, self.dtype) +# return self.C * max_class_weight + self.reg_lambda * self.radius() +# +# def kernel_regularizer(self): +# """ +# l2 loss using reg_lambda as the l2 term (as desired). Required for +# this loss function to be strongly convex. +# :return: +# """ +# return L1L2(l2=self.reg_lambda) diff --git a/privacy/bolton/loss_test.py b/privacy/bolton/loss_test.py index 87669fd..bb7dc53 100644 --- a/privacy/bolton/loss_test.py +++ b/privacy/bolton/loss_test.py @@ -1,3 +1,325 @@ +# Copyright 2018, The TensorFlow Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit testing for loss.py""" + from __future__ import absolute_import from __future__ import division -from __future__ import print_function \ No newline at end of file +from __future__ import print_function + +import tensorflow as tf +from tensorflow.python.platform import test +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.keras.optimizer_v2 import adam +from tensorflow.python.keras.optimizer_v2 import adagrad +from tensorflow.python.keras.optimizer_v2 import gradient_descent +from tensorflow.python.keras import losses +from tensorflow.python.framework import test_util +from privacy.bolton import model +from privacy.bolton.loss import StrongConvexBinaryCrossentropy +from privacy.bolton.loss import StrongConvexHuber +from privacy.bolton.loss import StrongConvexMixin +from absl.testing import parameterized +from absl.testing import absltest +from tensorflow.python.keras.regularizers import L1L2 + + +class StrongConvexTests(keras_parameterized.TestCase): + @parameterized.named_parameters([ + {'testcase_name': 'beta not implemented', + 'fn': 'beta', + 'args': [1]}, + {'testcase_name': 'gamma not implemented', + 'fn': 'gamma', + 'args': []}, + {'testcase_name': 'lipchitz not implemented', + 'fn': 'lipchitz_constant', + 'args': [1]}, + {'testcase_name': 'radius not implemented', + 'fn': 'radius', + 'args': []}, + ]) + def test_not_implemented(self, fn, args): + with self.assertRaises(NotImplementedError): + loss = StrongConvexMixin() + getattr(loss, fn, None)(*args) + + @parameterized.named_parameters([ + {'testcase_name': 'radius not implemented', + 'fn': 'kernel_regularizer', + 'args': []}, + ]) + def test_return_none(self, fn, args): + loss = StrongConvexMixin() + ret = getattr(loss, fn, None)(*args) + self.assertEqual(ret, None) + + +class BinaryCrossesntropyTests(keras_parameterized.TestCase): + """tests for BinaryCrossesntropy StrongConvex loss""" + + @parameterized.named_parameters([ + {'testcase_name': 'normal', + 'reg_lambda': 1, + 'c': 1, + 'radius_constant': 1 + }, + ]) + def test_init_params(self, reg_lambda, c, radius_constant): + # test valid domains for each variable + loss = StrongConvexBinaryCrossentropy(reg_lambda, c, radius_constant) + self.assertIsInstance(loss, StrongConvexBinaryCrossentropy) + + @parameterized.named_parameters([ + {'testcase_name': 'negative c', + 'reg_lambda': 1, + 'c': -1, + 'radius_constant': 1 + }, + {'testcase_name': 'negative radius', + 'reg_lambda': 1, + 'c': 1, + 'radius_constant': -1 + }, + {'testcase_name': 'negative lambda', + 'reg_lambda': -1, + 'c': 1, + 'radius_constant': 1 + }, + ]) + def test_bad_init_params(self, reg_lambda, c, radius_constant): + # test valid domains for each variable + with self.assertRaises(ValueError): + loss = StrongConvexBinaryCrossentropy(reg_lambda, c, radius_constant) + + @test_util.run_all_in_graph_and_eager_modes + @parameterized.named_parameters([ + # [] for compatibility with tensorflow loss calculation + {'testcase_name': 'both 
positive', + 'logits': [10000], + 'y_true': [1], + 'result': 0, + }, + {'testcase_name': 'positive gradient negative logits', + 'logits': [-10000], + 'y_true': [1], + 'result': 10000, + }, + {'testcase_name': 'positivee gradient positive logits', + 'logits': [10000], + 'y_true': [0], + 'result': 10000, + }, + {'testcase_name': 'both negative', + 'logits': [-10000], + 'y_true': [0], + 'result': 0 + }, + ]) + def test_calculation(self, logits, y_true, result): + logits = tf.Variable(logits, False, dtype=tf.float32) + y_true = tf.Variable(y_true, False, dtype=tf.float32) + loss = StrongConvexBinaryCrossentropy(0.00001, 1, 1) + loss = loss(y_true, logits) + self.assertEqual(loss.numpy(), result) + + @parameterized.named_parameters([ + {'testcase_name': 'beta', + 'init_args': [1, 1, 1], + 'fn': 'beta', + 'args': [1], + 'result': tf.constant(2, dtype=tf.float32) + }, + {'testcase_name': 'gamma', + 'fn': 'gamma', + 'init_args': [1, 1, 1], + 'args': [], + 'result': tf.constant(1, dtype=tf.float32), + }, + {'testcase_name': 'lipchitz constant', + 'fn': 'lipchitz_constant', + 'init_args': [1, 1, 1], + 'args': [1], + 'result': tf.constant(2, dtype=tf.float32), + }, + {'testcase_name': 'kernel regularizer', + 'fn': 'kernel_regularizer', + 'init_args': [1, 1, 1], + 'args': [], + 'result': L1L2(l2=1), + }, + ]) + def test_fns(self, init_args, fn, args, result): + loss = StrongConvexBinaryCrossentropy(*init_args) + expected = getattr(loss, fn, lambda: 'fn not found')(*args) + if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor + expected = expected.numpy() + result = result.numpy() + if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer + expected = expected.l2 + result = result.l2 + self.assertEqual(expected, result) + + +class HuberTests(keras_parameterized.TestCase): + """tests for BinaryCrossesntropy StrongConvex loss""" + + @parameterized.named_parameters([ + {'testcase_name': 'normal', + 'reg_lambda': 1, + 'c': 1, + 'radius_constant': 1, + 'delta': 1, + }, + ]) + def test_init_params(self, reg_lambda, c, radius_constant, delta): + # test valid domains for each variable + loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta) + self.assertIsInstance(loss, StrongConvexHuber) + + @parameterized.named_parameters([ + {'testcase_name': 'negative c', + 'reg_lambda': 1, + 'c': -1, + 'radius_constant': 1, + 'delta': 1 + }, + {'testcase_name': 'negative radius', + 'reg_lambda': 1, + 'c': 1, + 'radius_constant': -1, + 'delta': 1 + }, + {'testcase_name': 'negative lambda', + 'reg_lambda': -1, + 'c': 1, + 'radius_constant': 1, + 'delta': 1 + }, + {'testcase_name': 'negative delta', + 'reg_lambda': -1, + 'c': 1, + 'radius_constant': 1, + 'delta': -1 + }, + ]) + def test_bad_init_params(self, reg_lambda, c, radius_constant, delta): + # test valid domains for each variable + with self.assertRaises(ValueError): + loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta) + + # test the bounds and test varied delta's + @test_util.run_all_in_graph_and_eager_modes + @parameterized.named_parameters([ + {'testcase_name': 'delta=1,y_true=1 z>1+h decision boundary', + 'logits': 2.1, + 'y_true': 1, + 'delta': 1, + 'result': 0, + }, + {'testcase_name': 'delta=1,y_true=1 z<1+h decision boundary', + 'logits': 1.9, + 'y_true': 1, + 'delta': 1, + 'result': 0.01*0.25, + }, + {'testcase_name': 'delta=1,y_true=1 1-z< h decision boundary', + 'logits': 0.1, + 'y_true': 1, + 'delta': 1, + 'result': 1.9**2 * 0.25, + }, + {'testcase_name': 'delta=1,y_true=1 z < 1-h decision 
boundary', + 'logits': -0.1, + 'y_true': 1, + 'delta': 1, + 'result': 1.1, + }, + {'testcase_name': 'delta=2,y_true=1 z>1+h decision boundary', + 'logits': 3.1, + 'y_true': 1, + 'delta': 2, + 'result': 0, + }, + {'testcase_name': 'delta=2,y_true=1 z<1+h decision boundary', + 'logits': 2.9, + 'y_true': 1, + 'delta': 2, + 'result': 0.01*0.125, + }, + {'testcase_name': 'delta=2,y_true=1 1-z < h decision boundary', + 'logits': 1.1, + 'y_true': 1, + 'delta': 2, + 'result': 1.9**2 * 0.125, + }, + {'testcase_name': 'delta=2,y_true=1 z < 1-h decision boundary', + 'logits': -1.1, + 'y_true': 1, + 'delta': 2, + 'result': 2.1, + }, + {'testcase_name': 'delta=1,y_true=-1 z>1+h decision boundary', + 'logits': -2.1, + 'y_true': -1, + 'delta': 1, + 'result': 0, + }, + ]) + def test_calculation(self, logits, y_true, delta, result): + logits = tf.Variable(logits, False, dtype=tf.float32) + y_true = tf.Variable(y_true, False, dtype=tf.float32) + loss = StrongConvexHuber(0.00001, 1, 1, delta) + loss = loss(y_true, logits) + self.assertAllClose(loss.numpy(), result) + + @parameterized.named_parameters([ + {'testcase_name': 'beta', + 'init_args': [1, 1, 1, 1], + 'fn': 'beta', + 'args': [1], + 'result': tf.Variable(1.5, dtype=tf.float32) + }, + {'testcase_name': 'gamma', + 'fn': 'gamma', + 'init_args': [1, 1, 1, 1], + 'args': [], + 'result': tf.Variable(1, dtype=tf.float32), + }, + {'testcase_name': 'lipchitz constant', + 'fn': 'lipchitz_constant', + 'init_args': [1, 1, 1, 1], + 'args': [1], + 'result': tf.Variable(2, dtype=tf.float32), + }, + {'testcase_name': 'kernel regularizer', + 'fn': 'kernel_regularizer', + 'init_args': [1, 1, 1, 1], + 'args': [], + 'result': L1L2(l2=1), + }, + ]) + def test_fns(self, init_args, fn, args, result): + loss = StrongConvexHuber(*init_args) + expected = getattr(loss, fn, lambda: 'fn not found')(*args) + if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor + expected = expected.numpy() + result = result.numpy() + if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer + expected = expected.l2 + result = result.l2 + self.assertEqual(expected, result) + + +if __name__ == '__main__': + tf.test.main() \ No newline at end of file diff --git a/privacy/bolton/model.py b/privacy/bolton/model.py index a600374..78ceb7c 100644 --- a/privacy/bolton/model.py +++ b/privacy/bolton/model.py @@ -19,11 +19,12 @@ from __future__ import print_function import tensorflow as tf from tensorflow.python.keras.models import Model from tensorflow.python.keras import optimizers -from tensorflow.python.training.tracking import base as trackable from tensorflow.python.framework import ops as _ops -from privacy.bolton.loss import StrongConvexLoss +from privacy.bolton.loss import StrongConvexMixin from privacy.bolton.optimizer import Private +_accepted_distributions = ['laplace'] + class Bolton(Model): """ @@ -33,12 +34,16 @@ class Bolton(Model): 2. Projects weights to R after each batch 3. Limits learning rate 4. Use a strongly convex loss function (see compile) + + For more details on the strong convexity requirements, see: + Bolt-on Differential Privacy for Scalable Stochastic Gradient + Descent-based Analytics by Xi Wu et. al. """ + def __init__(self, n_classes, epsilon, noise_distribution='laplace', - weights_initializer=tf.initializers.GlorotUniform(), seed=1, dtype=tf.float32 ): @@ -59,6 +64,7 @@ class Bolton(Model): 2. Projects weights to R after each batch 3. 
Limits learning rate """ + def on_train_batch_end(self, batch, logs=None): loss = self.model.loss self.model.optimizer.limit_learning_rate( @@ -72,13 +78,17 @@ class Bolton(Model): loss = self.model.loss self.model._project_weights_to_r(loss.radius(), True) + if epsilon <= 0: + raise ValueError('Detected epsilon: {0}. ' + 'Valid range is 0 < epsilon
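Finally, a rough sketch of how the exported pieces are meant to fit together. The compile and fit calls assume the standard Keras Model interface that Bolton subclasses; the rest of model.py (including any Bolton-specific compile arguments) is not shown in this excerpt, so treat the details as assumptions.

# Rough sketch only; Bolton's compile/fit overrides are not visible in this hunk.
import tensorflow as tf
from privacy.bolton.model import Bolton
from privacy.bolton.loss import StrongConvexBinaryCrossentropy

model = Bolton(n_classes=1, epsilon=1, noise_distribution='laplace')
loss = StrongConvexBinaryCrossentropy(reg_lambda=1, C=1, radius_constant=1)
model.compile(optimizer='sgd', loss=loss)  # loss must implement StrongConvexMixin
# model.fit(x_train, y_train)              # training data is user-supplied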