diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 8314d95..199d32a 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -21,6 +21,7 @@ from absl.testing import parameterized import numpy as np import pandas as pd from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import _log_value from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve @@ -65,6 +66,33 @@ class AttackInputDataTest(absltest.TestCase): np.testing.assert_equal(attack_input.get_loss_test().tolist(), [1.0, 4.0, 6.0]) + def test_get_entropy(self): + attack_input = AttackInputData( + logits_train=np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), + logits_test=np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]), + labels_train=np.array([0, 2]), + labels_test=np.array([0, 2])) + + np.testing.assert_equal(attack_input.get_entropy_train().tolist(), [0, 0]) + np.testing.assert_equal(attack_input.get_entropy_test().tolist(), [2*_log_value(0), 0]) + + attack_input = AttackInputData( + logits_train=np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]), + logits_test=np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])) + + np.testing.assert_equal(attack_input.get_entropy_train().tolist(), [0, 0]) + np.testing.assert_equal(attack_input.get_entropy_test().tolist(), [0, 0]) + + def test_get_entropy_explicitly_provided(self): + attack_input = AttackInputData( + entropy_train=np.array([0.0, 2.0, 1.0]), + entropy_test=np.array([0.5, 3.0, 5.0])) + + np.testing.assert_equal(attack_input.get_entropy_train().tolist(), + [0.0, 2.0, 1.0]) + np.testing.assert_equal(attack_input.get_entropy_test().tolist(), + [0.5, 3.0, 5.0]) + def test_validator(self): self.assertRaises(ValueError, AttackInputData(logits_train=np.array([])).validate) @@ -72,12 +100,16 @@ class AttackInputDataTest(absltest.TestCase): AttackInputData(labels_train=np.array([])).validate) self.assertRaises(ValueError, AttackInputData(loss_train=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData(entropy_train=np.array([])).validate) self.assertRaises(ValueError, AttackInputData(logits_test=np.array([])).validate) self.assertRaises(ValueError, AttackInputData(labels_test=np.array([])).validate) self.assertRaises(ValueError, AttackInputData(loss_test=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData(entropy_test=np.array([])).validate) self.assertRaises(ValueError, AttackInputData().validate)