more fixes

npapernot 2019-07-25 16:13:32 +00:00
parent 8e6bcf9b4a
commit 8974a95b9a
8 changed files with 114 additions and 107 deletions

View file

@@ -16,7 +16,7 @@ import sys
from distutils.version import LooseVersion
import tensorflow as tf
if LooseVersion(tf.__version__) < LooseVersion('2.0.0'):
if LooseVersion(tf.__version__) < LooseVersion("2.0.0"):
raise ImportError("Please upgrade your version "
"of tensorflow from: {0} to at least 2.0.0 to "
"use privacy/bolton".format(LooseVersion(tf.__version__)))

View file

@@ -102,11 +102,11 @@ class StrongConvexHuber(losses.Loss, StrongConvexMixin):
"""Strong Convex version of Huber loss using l2 weight regularization."""
def __init__(self,
reg_lambda: float,
C: float,
radius_constant: float,
delta: float,
reduction: str = losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
reg_lambda,
C,
radius_constant,
delta,
reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
dtype=tf.float32):
"""Constructor.

View file

@@ -261,8 +261,9 @@ class HuberTests(keras_parameterized.TestCase):
Args:
reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg
c: initialization value for C arg
radius_constant: initialization value for radius_constant arg
delta: the delta parameter for the huber loss
"""
# test valid domains for each variable
loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta)
@@ -295,11 +296,11 @@ class HuberTests(keras_parameterized.TestCase):
},
])
def test_bad_init_params(self, reg_lambda, c, radius_constant, delta):
"""Test invalid domain for given params. Should return ValueError
"""Test invalid domain for given params. Should return ValueError.
Args:
reg_lambda: initialization value for reg_lambda arg
C: initialization value for C arg
c: initialization value for C arg
radius_constant: initialization value for radius_constant arg
delta: the delta parameter for the huber loss
"""
@@ -406,7 +407,7 @@ class HuberTests(keras_parameterized.TestCase):
},
])
def test_fns(self, init_args, fn, args, result):
"""Test that fn of BinaryCrossentropy loss returns the correct result
"""Test that fn of BinaryCrossentropy loss returns the correct result.
Args:
init_args: init values for loss instance

View file

@@ -86,10 +86,12 @@ class BoltonModel(Model): # pylint: disable=abstract-method
**kwargs): # pylint: disable=arguments-differ
"""See super class. Default optimizer used in Bolton method is SGD.
Missing args.
"""
if not isinstance(loss, StrongConvexMixin):
raise ValueError("loss function must be a Strongly Convex and therefore "
"extend the StrongConvexMixin.")
raise ValueError('loss function must be a Strongly Convex and therefore '
'extend the StrongConvexMixin.')
if not self._layers_instantiated: # compile may be called multiple times
# for instance, if the input/outputs are not defined until fit.
self.output_layer = tf.keras.layers.Dense(
@@ -150,7 +152,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
data_size = n_samples
elif hasattr(x, 'shape'):
data_size = x.shape[0]
elif hasattr(x, "__len__"):
elif hasattr(x, '__len__'):
data_size = len(x)
else:
data_size = None
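
The data-size resolution order used here (and again in fit_generator below) reads cleanly as a standalone helper; resolve_data_size is a hypothetical name for illustration only.

def resolve_data_size(x, n_samples=None):
  # Prefer an explicit n_samples, then a .shape, then len(); else unknown.
  if n_samples is not None:
    return n_samples
  if hasattr(x, 'shape'):
    return x.shape[0]
  if hasattr(x, '__len__'):
    return len(x)
  return None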
@@ -187,10 +189,12 @@ class BoltonModel(Model): # pylint: disable=abstract-method
n_samples=None,
steps_per_epoch=None,
**kwargs): # pylint: disable=arguments-differ
"""
This method is the same as fit except for when the passed dataset
is a generator. See super method and fit for more details.
Args:
"""Fit with a generator..
This method is the same as fit except for when the passed dataset
is a generator. See super method and fit for more details.
Args:
n_samples: number of individual samples in x
noise_distribution: the distribution to get noise from.
epsilon: privacy parameter, which trades off utility and privacy. See
@@ -206,7 +210,7 @@ class BoltonModel(Model): # pylint: disable=abstract-method
data_size = n_samples
elif hasattr(generator, 'shape'):
data_size = generator.shape[0]
elif hasattr(generator, "__len__"):
elif hasattr(generator, '__len__'):
data_size = len(generator)
else:
data_size = None
@@ -232,13 +236,14 @@ class BoltonModel(Model): # pylint: disable=abstract-method
num_classes=None):
"""Calculates class weighting to be used in training.
Args:
Args:
class_weights: str specifying type, array giving weights, or None.
class_counts: If class_weights is not None, then an array of
the number of samples for each class
num_classes: If class_weights is not None, then the number of
classes.
Returns: class_weights as 1D tensor, to be passed to model's fit method.
Returns:
class_weights as 1D tensor, to be passed to model's fit method.
"""
# Value checking
class_keys = ['balanced']
@@ -246,14 +251,14 @@ class BoltonModel(Model): # pylint: disable=abstract-method
if isinstance(class_weights, str):
is_string = True
if class_weights not in class_keys:
raise ValueError("Detected string class_weights with "
"value: {0}, which is not one of {1}."
"Please select a valid class_weight type"
"or pass an array".format(class_weights,
raise ValueError('Detected string class_weights with '
'value: {0}, which is not one of {1}.'
'Please select a valid class_weight type'
'or pass an array'.format(class_weights,
class_keys))
if class_counts is None:
raise ValueError("Class counts must be provided if using "
"class_weights=%s" % class_weights)
raise ValueError('Class counts must be provided if using '
'class_weights=%s' % class_weights)
class_counts_shape = tf.Variable(class_counts,
trainable=False,
dtype=self._dtype).shape
@@ -261,12 +266,12 @@ class BoltonModel(Model): # pylint: disable=abstract-method
raise ValueError('class counts must be a 1D array.'
'Detected: {0}'.format(class_counts_shape))
if num_classes is None:
raise ValueError("num_classes must be provided if using "
"class_weights=%s" % class_weights)
raise ValueError('num_classes must be provided if using '
'class_weights=%s' % class_weights)
elif class_weights is not None:
if num_classes is None:
raise ValueError("You must pass a value for num_classes if "
"creating an array of class_weights")
raise ValueError('You must pass a value for num_classes if '
'creating an array of class_weights')
# performing class weight calculation
if class_weights is None:
class_weights = 1
@@ -280,11 +285,11 @@ class BoltonModel(Model): # pylint: disable=abstract-method
else:
class_weights = _ops.convert_to_tensor_v2(class_weights)
if len(class_weights.shape) != 1:
raise ValueError("Detected class_weights shape: {0} instead of "
"1D array".format(class_weights.shape))
raise ValueError('Detected class_weights shape: {0} instead of '
'1D array'.format(class_weights.shape))
if class_weights.shape[0] != num_classes:
raise ValueError(
"Detected array length: {0} instead of: {1}".format(
'Detected array length: {0} instead of: {1}'.format(
class_weights.shape[0],
num_classes))
return class_weights
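
A hedged usage sketch of calculate_class_weights, reusing the BoltonModel(1, 1) construction that appears in the tests below; the class counts are illustrative.

from privacy.bolton import models

clf = models.BoltonModel(1, 1)
# 'balanced' is the only supported string key (class_keys above);
# class_counts must be a 1D array and num_classes must accompany it.
weights = clf.calculate_class_weights(class_weights='balanced',
                                      class_counts=[30, 70],
                                      num_classes=2)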

View file

@@ -17,17 +17,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
from tensorflow.python.keras import losses
from tensorflow.python.framework import ops as _ops
from tensorflow.python.keras.regularizers import L1L2
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.framework import ops as _ops
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import losses
from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2
from tensorflow.python.keras.regularizers import L1L2
from privacy.bolton import models
from privacy.bolton.optimizers import Bolton
from privacy.bolton.losses import StrongConvexMixin
from privacy.bolton.optimizers import Bolton
class TestLoss(losses.Loss, StrongConvexMixin):
@@ -41,9 +40,11 @@ class TestLoss(losses.Loss, StrongConvexMixin):
def radius(self):
"""Radius, R, of the hypothesis space W.
W is a convex set that forms the hypothesis space.
Returns: radius
Returns:
radius
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
@@ -69,7 +70,8 @@ class TestLoss(losses.Loss, StrongConvexMixin):
Args:
class_weight: class weights used
Returns: L
Returns:
L
"""
return _ops.convert_to_tensor_v2(1, dtype=tf.float32)
@@ -81,11 +83,10 @@ class TestLoss(losses.Loss, StrongConvexMixin):
)
def max_class_weight(self, class_weight):
"""the maximum weighting in class weights (max value) as a scalar tensor
"""the maximum weighting in class weights (max value) as a scalar tensor.
Args:
class_weight: class weights used
dtype: the data type for tensor conversions.
Returns:
maximum class weighting as tensor scalar
@@ -104,7 +105,7 @@ class TestLoss(losses.Loss, StrongConvexMixin):
class TestOptimizer(OptimizerV2):
"""Test optimizer used for testing Bolton model"""
"""Test optimizer used for testing Bolton model."""
def __init__(self):
super(TestOptimizer, self).__init__('test')
@@ -152,7 +153,7 @@ class InitTests(keras_parameterized.TestCase):
},
])
def test_bad_init_params(self, n_outputs):
"""test bad initializations of BoltonModel that should raise errors
"""test bad initializations of BoltonModel that should raise errors.
Args:
n_outputs: number of output neurons
@@ -174,12 +175,12 @@ class InitTests(keras_parameterized.TestCase):
},
])
def test_compile(self, n_outputs, loss, optimizer):
"""test compilation of BoltonModel
"""test compilation of BoltonModel.
Args:
n_outputs: number of output neurons
loss: instantiated TestLoss instance
optimizer: instanced TestOptimizer instance
n_outputs: number of output neurons
loss: instantiated TestLoss instance
optimizer: instanced TestOptimizer instance
"""
# test compilation of valid tf.optimizer and tf.loss
with self.cached_session():
@@ -200,12 +201,12 @@ class InitTests(keras_parameterized.TestCase):
}
])
def test_bad_compile(self, n_outputs, loss, optimizer):
"""test bad compilations of BoltonModel that should raise errors
"""test bad compilations of BoltonModel that should raise errors.
Args:
n_outputs: number of output neurons
loss: instantiated TestLoss instance
optimizer: instanced TestOptimizer instance
n_outputs: number of output neurons
loss: instantiated TestLoss instance
optimizer: instanced TestOptimizer instance
"""
# test compilaton of invalid tf.optimizer and non instantiated loss.
with self.cached_session():
@@ -215,17 +216,18 @@ class InitTests(keras_parameterized.TestCase):
def _cat_dataset(n_samples, input_dim, n_classes, generator=False):
"""
Creates a categorically encoded dataset (y is categorical).
returns the specified dataset either as a static array or as a generator.
Will have evenly split samples across each output class.
Each output class will be a different point in the input space.
"""Creates a categorically encoded dataset.
Creates a categorically encoded dataset (y is categorical).
returns the specified dataset either as a static array or as a generator.
Will have evenly split samples across each output class.
Each output class will be a different point in the input space.
Args:
n_samples: number of rows
input_dim: input dimensionality
n_classes: output dimensionality
generator: False for array, True for generator
n_samples: number of rows
input_dim: input dimensionality
n_classes: output dimensionality
generator: False for array, True for generator
Returns:
X as (n_samples, input_dim), Y as (n_samples, n_outputs)
"""
@@ -246,6 +248,7 @@ def _cat_dataset(n_samples, input_dim, n_classes, generator=False):
return dataset
return x_set, y_set
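
A self-contained sketch of the dataset this docstring describes, assuming each class is a distinct constant point in input space and y is one-hot; cat_dataset_sketch is a hypothetical stand-in for _cat_dataset.

import tensorflow as tf

def cat_dataset_sketch(n_samples, input_dim, n_classes):
  per_class = n_samples // n_classes  # evenly split samples across classes
  x_stack, y_stack = [], []
  for i in range(n_classes):
    # Class i sits at the constant input point (i, i, ..., i).
    x_stack.append(tf.constant(i, tf.float32, (per_class, input_dim)))
    y_stack.append(tf.one_hot([i] * per_class, n_classes))
  return tf.concat(x_stack, 0), tf.concat(y_stack, 0)

x, y = cat_dataset_sketch(10, 2, 2)  # x: (10, 2), y: (10, 2)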
def _do_fit(n_samples,
input_dim,
n_outputs,
@@ -301,7 +304,7 @@ def _do_fit(n_samples,
class FitTests(keras_parameterized.TestCase):
"""Test cases for keras model fitting"""
"""Test cases for keras model fitting."""
# @test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters([
@@ -323,7 +326,7 @@ class FitTests(keras_parameterized.TestCase):
},
])
def test_fit(self, generator, reset_n_samples):
"""Tests fitting of BoltonModel
"""Tests fitting of BoltonModel.
Args:
generator: True for generator test, False for iterator test.
@@ -355,7 +358,7 @@ class FitTests(keras_parameterized.TestCase):
},
])
def test_fit_gen(self, generator):
"""Tests the fit_generator method of BoltonModel
"""Tests the fit_generator method of BoltonModel.
Args:
generator: True to test with a generator dataset
@@ -392,7 +395,7 @@ class FitTests(keras_parameterized.TestCase):
},
])
def test_bad_fit(self, generator, reset_n_samples, distribution):
"""Tests fitting with invalid parameters, which should raise an error
"""Tests fitting with invalid parameters, which should raise an error.
Args:
generator: True to test with generator, False is iterator
@@ -442,9 +445,8 @@ class FitTests(keras_parameterized.TestCase):
class_weights,
class_counts,
num_classes,
result
):
"""Tests the BOltonModel calculate_class_weights method
result):
"""Tests the BOltonModel calculate_class_weights method.
Args:
class_weights: the class_weights to use
@@ -496,26 +498,28 @@ class FitTests(keras_parameterized.TestCase):
'class_weights': [[1], [1]],
'class_counts': None,
'num_classes': 2,
'err_msg': "Detected class_weights shape"},
'err_msg': 'Detected class_weights shape'},
{'testcase_name': 'class counts array, wrong number classes',
'class_weights': [1, 1, 1],
'class_counts': None,
'num_classes': 2,
'err_msg': 'Detected array length:'},
])
def test_class_errors(self,
class_weights,
class_counts,
num_classes,
err_msg):
"""Tests the BOltonModel calculate_class_weights method with invalid params
which should raise the expected errors.
"""Tests the BOltonModel calculate_class_weights method.
This test passes invalid params which should raise the expected errors.
Args:
class_weights: the class_weights to use
class_counts: count of number of samples for each class
num_classes: number of outputs neurons
result: expected result
err_msg: the expected error message
"""
clf = models.BoltonModel(1, 1)
with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method

View file

@@ -108,8 +108,8 @@ class Bolton(optimizer_v2.OptimizerV2):
Descent-based Analytics by Xi Wu et. al.
"""
def __init__(self, # pylint: disable=super-init-not-called
optimizer: optimizer_v2.OptimizerV2,
loss: StrongConvexMixin,
optimizer,
loss,
dtype=tf.float32,
):
"""Constructor.

View file

@@ -263,12 +263,7 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
def test_project(self, r, shape, n_out, init_value, result):
"""test that a fn of Bolton optimizer is working as expected.
Args:
fn: method of Optimizer to test
args: args to optimizer fn
result: the expected result
test_attr: None if the fn returns the test result. Otherwise, this is
the attribute of Bolton to check against result with.
Missing args:
"""
tf.random.set_seed(1)
@@ -455,12 +450,14 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
'args': [1, 1]},
])
def test_not_reroute_fn(self, fn, args):
"""Test that a fn that should not be rerouted to the internal optimizer is
in face not rerouted.
"""Test function is not rerouted.
Test that a fn that should not be rerouted to the internal optimizer is
in fact not rerouted.
Args:
fn: fn to test
args: arguments to that fn
Args:
fn: fn to test
args: arguments to that fn
"""
@tf.function
def test_run(fn, args):
@@ -492,12 +489,13 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
'attr': '_iterations'}
])
def test_reroute_attr(self, attr):
""" test that attribute of internal optimizer is correctly rerouted to
the internal optimizer
"""Test a function is rerouted.
Test that attribute of internal optimizer is correctly rerouted to the
internal optimizer.
Args:
attr: attribute to test
result: result after checking attribute
Args:
attr: attribute to test
"""
loss = TestLoss(1, 1, 1)
internal_optimizer = TestOptimizer()
@@ -510,12 +508,13 @@ class BoltonOptimizerTest(keras_parameterized.TestCase):
'attr': '_not_valid'}
])
def test_attribute_error(self, attr):
"""Test that attribute of internal optimizer is correctly rerouted to
the internal optimizer
"""Test rerouting of attributes.
Test that attribute of internal optimizer is correctly rerouted to the
internal optimizer.
Args:
attr: attribute to test
result: result after checking attribute
Args:
attr: attribute to test
"""
loss = TestLoss(1, 1, 1)
internal_optimizer = TestOptimizer()
@@ -537,9 +536,7 @@ class SchedulerTest(keras_parameterized.TestCase):
""" test that attribute of internal optimizer is correctly rerouted to
the internal optimizer
Args:
attr: attribute to test
result: result after checking attribute
Missing args
"""
scheduler = opt.GammaBetaDecreasingStep()
with self.assertRaisesRegexp(Exception, err_msg): # pylint: disable=deprecated-method
@@ -557,12 +554,12 @@ class SchedulerTest(keras_parameterized.TestCase):
'res': 0.333333333},
])
def test_call(self, step, res):
""" test that attribute of internal optimizer is correctly rerouted to
the internal optimizer
"""Test call.
Test that attribute of internal optimizer is correctly rerouted to the
internal optimizer.
Args:
attr: attribute to test
result: result after checking attribute
Missing Args:
"""
beta = _ops.convert_to_tensor_v2(2, dtype=tf.float32)
gamma = _ops.convert_to_tensor_v2(1, dtype=tf.float32)
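
For reference, a minimal sketch of the step size this test appears to check, assuming GammaBetaDecreasingStep computes min(1/beta, 1/(gamma * t)); with beta=2 and gamma=1 that yields the 0.333333333 expected above at step 3.

def gamma_beta_step(t, beta=2.0, gamma=1.0):
  # Constant at 1/beta early on, then decays as 1/(gamma * t).
  return min(1.0 / beta, 1.0 / (gamma * t))

assert abs(gamma_beta_step(3) - 0.333333333) < 1e-6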

View file

@@ -116,7 +116,7 @@ try:
noise_distribution=noise_distribution,
verbose=0)
except ValueError as e:
print e
print(e)
# -------
# And now, re running with the parameter set.
# -------