forked from 626_privacy/tensorflow_privacy
conflicts in models test
This commit is contained in:
parent
d10d7b0148
commit
034ae8fea4
1 changed files with 2 additions and 19 deletions
|
@ -206,13 +206,8 @@ class InitTests(keras_parameterized.TestCase):
|
||||||
Args:
|
Args:
|
||||||
n_outputs: number of output neurons
|
n_outputs: number of output neurons
|
||||||
loss: instantiated TestLoss instance
|
loss: instantiated TestLoss instance
|
||||||
<<<<<<< HEAD
|
|
||||||
optimizer: instanced TestOptimizer instance
|
|
||||||
"""
|
|
||||||
=======
|
|
||||||
optimizer: instantiated TestOptimizer instance
|
optimizer: instantiated TestOptimizer instance
|
||||||
"""
|
"""
|
||||||
>>>>>>> 71c4a11eb9ad66a78fb13428987366887ea20beb
|
|
||||||
# test compilaton of invalid tf.optimizer and non instantiated loss.
|
# test compilaton of invalid tf.optimizer and non instantiated loss.
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
with self.assertRaises((ValueError, AttributeError)):
|
with self.assertRaises((ValueError, AttributeError)):
|
||||||
|
@ -517,17 +512,6 @@ class FitTests(keras_parameterized.TestCase):
|
||||||
num_classes,
|
num_classes,
|
||||||
err_msg):
|
err_msg):
|
||||||
"""Tests the BOltonModel calculate_class_weights method.
|
"""Tests the BOltonModel calculate_class_weights method.
|
||||||
<<<<<<< HEAD
|
|
||||||
|
|
||||||
This test passes invalid params which should raise the expected errors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
class_weights: the class_weights to use
|
|
||||||
class_counts: count of number of samples for each class
|
|
||||||
num_classes: number of outputs neurons
|
|
||||||
err_msg:
|
|
||||||
"""
|
|
||||||
=======
|
|
||||||
|
|
||||||
This test passes invalid params which should raise the expected errors.
|
This test passes invalid params which should raise the expected errors.
|
||||||
|
|
||||||
|
@ -536,8 +520,7 @@ class FitTests(keras_parameterized.TestCase):
|
||||||
class_counts: count of number of samples for each class.
|
class_counts: count of number of samples for each class.
|
||||||
num_classes: number of outputs neurons.
|
num_classes: number of outputs neurons.
|
||||||
err_msg: The expected error message.
|
err_msg: The expected error message.
|
||||||
"""
|
"""
|
||||||
>>>>>>> 71c4a11eb9ad66a78fb13428987366887ea20beb
|
|
||||||
clf = models.BoltonModel(1, 1)
|
clf = models.BoltonModel(1, 1)
|
||||||
with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method
|
with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method
|
||||||
clf.calculate_class_weights(class_weights,
|
clf.calculate_class_weights(class_weights,
|
||||||
|
|
Loading…
Reference in a new issue