update slicing test

This commit is contained in:
Liwei Song 2020-10-21 17:07:53 -04:00
parent a41d6aace7
commit 0fa87d200c

View file

@ -114,6 +114,8 @@ class GetSliceTest(absltest.TestCase):
labels_test = np.array([1, 2, 0, 2])
loss_train = np.array([2, 0.25, 4, 3])
loss_test = np.array([0.5, 3.5, 7, 4.5])
entropy_train = np.array([0.4, 8, 0.6, 10])
entropy_test = np.array([15, 10.5, 4.5, 0.3])
self.input_data = AttackInputData(
logits_train=logits_train,
@ -123,7 +125,9 @@ class GetSliceTest(absltest.TestCase):
labels_train=labels_train,
labels_test=labels_test,
loss_train=loss_train,
loss_test=loss_test)
loss_test=loss_test,
entropy_train=entropy_train,
entropy_test=entropy_test)
def test_slice_entire_dataset(self):
entire_dataset_slice = SingleSliceSpec()
@ -159,6 +163,12 @@ class GetSliceTest(absltest.TestCase):
self.assertTrue((output.loss_train == [2, 4]).all())
self.assertTrue((output.loss_test == [0.5]).all())
# Check entropy
self.assertLen(output.entropy_train, 2)
self.assertLen(output.entropy_test, 1)
self.assertTrue((output.entropy_train == [0.4, 0.6]).all())
self.assertTrue((output.entropy_test == [15]).all())
def test_slice_by_percentile(self):
percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))
output = get_slice(self.input_data, percentile_slice)