conflicts in models test

This commit is contained in:
npapernot 2019-07-29 21:29:03 +00:00
parent d10d7b0148
commit 034ae8fea4

View file

@ -206,13 +206,8 @@ class InitTests(keras_parameterized.TestCase):
Args: Args:
n_outputs: number of output neurons n_outputs: number of output neurons
loss: instantiated TestLoss instance loss: instantiated TestLoss instance
<<<<<<< HEAD
optimizer: instanced TestOptimizer instance
"""
=======
optimizer: instantiated TestOptimizer instance optimizer: instantiated TestOptimizer instance
""" """
>>>>>>> 71c4a11eb9ad66a78fb13428987366887ea20beb
# test compilaton of invalid tf.optimizer and non instantiated loss. # test compilaton of invalid tf.optimizer and non instantiated loss.
with self.cached_session(): with self.cached_session():
with self.assertRaises((ValueError, AttributeError)): with self.assertRaises((ValueError, AttributeError)):
@ -517,17 +512,6 @@ class FitTests(keras_parameterized.TestCase):
num_classes, num_classes,
err_msg): err_msg):
"""Tests the BOltonModel calculate_class_weights method. """Tests the BOltonModel calculate_class_weights method.
<<<<<<< HEAD
This test passes invalid params which should raise the expected errors.
Args:
class_weights: the class_weights to use
class_counts: count of number of samples for each class
num_classes: number of outputs neurons
err_msg:
"""
=======
This test passes invalid params which should raise the expected errors. This test passes invalid params which should raise the expected errors.
@ -536,8 +520,7 @@ class FitTests(keras_parameterized.TestCase):
class_counts: count of number of samples for each class. class_counts: count of number of samples for each class.
num_classes: number of outputs neurons. num_classes: number of outputs neurons.
err_msg: The expected error message. err_msg: The expected error message.
""" """
>>>>>>> 71c4a11eb9ad66a78fb13428987366887ea20beb
clf = models.BoltonModel(1, 1) clf = models.BoltonModel(1, 1)
with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method
clf.calculate_class_weights(class_weights, clf.calculate_class_weights(class_weights,