forked from 626_privacy/tensorflow_privacy
update slicing test
This commit is contained in:
parent
a41d6aace7
commit
0fa87d200c
1 changed files with 11 additions and 1 deletions
|
@ -114,6 +114,8 @@ class GetSliceTest(absltest.TestCase):
|
||||||
labels_test = np.array([1, 2, 0, 2])
|
labels_test = np.array([1, 2, 0, 2])
|
||||||
loss_train = np.array([2, 0.25, 4, 3])
|
loss_train = np.array([2, 0.25, 4, 3])
|
||||||
loss_test = np.array([0.5, 3.5, 7, 4.5])
|
loss_test = np.array([0.5, 3.5, 7, 4.5])
|
||||||
|
entropy_train = np.array([0.4, 8, 0.6, 10])
|
||||||
|
entropy_test = np.array([15, 10.5, 4.5, 0.3])
|
||||||
|
|
||||||
self.input_data = AttackInputData(
|
self.input_data = AttackInputData(
|
||||||
logits_train=logits_train,
|
logits_train=logits_train,
|
||||||
|
@ -123,7 +125,9 @@ class GetSliceTest(absltest.TestCase):
|
||||||
labels_train=labels_train,
|
labels_train=labels_train,
|
||||||
labels_test=labels_test,
|
labels_test=labels_test,
|
||||||
loss_train=loss_train,
|
loss_train=loss_train,
|
||||||
loss_test=loss_test)
|
loss_test=loss_test,
|
||||||
|
entropy_train=entropy_train,
|
||||||
|
entropy_test=entropy_test)
|
||||||
|
|
||||||
def test_slice_entire_dataset(self):
|
def test_slice_entire_dataset(self):
|
||||||
entire_dataset_slice = SingleSliceSpec()
|
entire_dataset_slice = SingleSliceSpec()
|
||||||
|
@ -158,6 +162,12 @@ class GetSliceTest(absltest.TestCase):
|
||||||
self.assertLen(output.loss_test, 1)
|
self.assertLen(output.loss_test, 1)
|
||||||
self.assertTrue((output.loss_train == [2, 4]).all())
|
self.assertTrue((output.loss_train == [2, 4]).all())
|
||||||
self.assertTrue((output.loss_test == [0.5]).all())
|
self.assertTrue((output.loss_test == [0.5]).all())
|
||||||
|
|
||||||
|
# Check entropy
|
||||||
|
self.assertLen(output.entropy_train, 2)
|
||||||
|
self.assertLen(output.entropy_test, 1)
|
||||||
|
self.assertTrue((output.entropy_train == [0.4, 0.6]).all())
|
||||||
|
self.assertTrue((output.entropy_test == [15]).all())
|
||||||
|
|
||||||
def test_slice_by_percentile(self):
|
def test_slice_by_percentile(self):
|
||||||
percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))
|
percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))
|
||||||
|
|
Loading…
Reference in a new issue