update slicing test
This commit is contained in:
parent
a41d6aace7
commit
0fa87d200c
1 changed files with 11 additions and 1 deletions
|
@ -114,6 +114,8 @@ class GetSliceTest(absltest.TestCase):
|
|||
labels_test = np.array([1, 2, 0, 2])
|
||||
loss_train = np.array([2, 0.25, 4, 3])
|
||||
loss_test = np.array([0.5, 3.5, 7, 4.5])
|
||||
entropy_train = np.array([0.4, 8, 0.6, 10])
|
||||
entropy_test = np.array([15, 10.5, 4.5, 0.3])
|
||||
|
||||
self.input_data = AttackInputData(
|
||||
logits_train=logits_train,
|
||||
|
@ -123,7 +125,9 @@ class GetSliceTest(absltest.TestCase):
|
|||
labels_train=labels_train,
|
||||
labels_test=labels_test,
|
||||
loss_train=loss_train,
|
||||
loss_test=loss_test)
|
||||
loss_test=loss_test,
|
||||
entropy_train=entropy_train,
|
||||
entropy_test=entropy_test)
|
||||
|
||||
def test_slice_entire_dataset(self):
|
||||
entire_dataset_slice = SingleSliceSpec()
|
||||
|
@ -159,6 +163,12 @@ class GetSliceTest(absltest.TestCase):
|
|||
self.assertTrue((output.loss_train == [2, 4]).all())
|
||||
self.assertTrue((output.loss_test == [0.5]).all())
|
||||
|
||||
# Check entropy
|
||||
self.assertLen(output.entropy_train, 2)
|
||||
self.assertLen(output.entropy_test, 1)
|
||||
self.assertTrue((output.entropy_train == [0.4, 0.6]).all())
|
||||
self.assertTrue((output.entropy_test == [15]).all())
|
||||
|
||||
def test_slice_by_percentile(self):
|
||||
percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))
|
||||
output = get_slice(self.input_data, percentile_slice)
|
||||
|
|
Loading…
Reference in a new issue