diff --git a/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing_test.py b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing_test.py index 75b8a3f..539870d 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing_test.py @@ -114,6 +114,8 @@ class GetSliceTest(absltest.TestCase): labels_test = np.array([1, 2, 0, 2]) loss_train = np.array([2, 0.25, 4, 3]) loss_test = np.array([0.5, 3.5, 7, 4.5]) + entropy_train = np.array([0.4, 8, 0.6, 10]) + entropy_test = np.array([15, 10.5, 4.5, 0.3]) self.input_data = AttackInputData( logits_train=logits_train, @@ -123,7 +125,9 @@ class GetSliceTest(absltest.TestCase): labels_train=labels_train, labels_test=labels_test, loss_train=loss_train, - loss_test=loss_test) + loss_test=loss_test, + entropy_train=entropy_train, + entropy_test=entropy_test) def test_slice_entire_dataset(self): entire_dataset_slice = SingleSliceSpec() @@ -158,6 +162,12 @@ class GetSliceTest(absltest.TestCase): self.assertLen(output.loss_test, 1) self.assertTrue((output.loss_train == [2, 4]).all()) self.assertTrue((output.loss_test == [0.5]).all()) + + # Check entropy + self.assertLen(output.entropy_train, 2) + self.assertLen(output.entropy_test, 1) + self.assertTrue((output.entropy_train == [0.4, 0.6]).all()) + self.assertTrue((output.entropy_test == [15]).all()) def test_slice_by_percentile(self): percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50))