add entropy tests
This commit is contained in:
parent
9b2e6a55b6
commit
0e1c1eeef3
1 changed files with 32 additions and 0 deletions
|
@ -21,6 +21,7 @@ from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import _log_value
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||||
|
@ -65,6 +66,33 @@ class AttackInputDataTest(absltest.TestCase):
|
||||||
np.testing.assert_equal(attack_input.get_loss_test().tolist(),
|
np.testing.assert_equal(attack_input.get_loss_test().tolist(),
|
||||||
[1.0, 4.0, 6.0])
|
[1.0, 4.0, 6.0])
|
||||||
|
|
||||||
|
def test_get_entropy(self):
|
||||||
|
attack_input = AttackInputData(
|
||||||
|
logits_train=np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
|
||||||
|
logits_test=np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
|
||||||
|
labels_train=np.array([0, 2]),
|
||||||
|
labels_test=np.array([0, 2]))
|
||||||
|
|
||||||
|
np.testing.assert_equal(attack_input.get_entropy_train().tolist(), [0, 0])
|
||||||
|
np.testing.assert_equal(attack_input.get_entropy_test().tolist(), [2*_log_value(0), 0])
|
||||||
|
|
||||||
|
attack_input = AttackInputData(
|
||||||
|
logits_train=np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
|
||||||
|
logits_test=np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))
|
||||||
|
|
||||||
|
np.testing.assert_equal(attack_input.get_entropy_train().tolist(), [0, 0])
|
||||||
|
np.testing.assert_equal(attack_input.get_entropy_test().tolist(), [0, 0])
|
||||||
|
|
||||||
|
def test_get_entropy_explicitly_provided(self):
|
||||||
|
attack_input = AttackInputData(
|
||||||
|
entropy_train=np.array([0.0, 2.0, 1.0]),
|
||||||
|
entropy_test=np.array([0.5, 3.0, 5.0]))
|
||||||
|
|
||||||
|
np.testing.assert_equal(attack_input.get_entropy_train().tolist(),
|
||||||
|
[0.0, 2.0, 1.0])
|
||||||
|
np.testing.assert_equal(attack_input.get_entropy_test().tolist(),
|
||||||
|
[0.5, 3.0, 5.0])
|
||||||
|
|
||||||
def test_validator(self):
|
def test_validator(self):
|
||||||
self.assertRaises(ValueError,
|
self.assertRaises(ValueError,
|
||||||
AttackInputData(logits_train=np.array([])).validate)
|
AttackInputData(logits_train=np.array([])).validate)
|
||||||
|
@ -72,12 +100,16 @@ class AttackInputDataTest(absltest.TestCase):
|
||||||
AttackInputData(labels_train=np.array([])).validate)
|
AttackInputData(labels_train=np.array([])).validate)
|
||||||
self.assertRaises(ValueError,
|
self.assertRaises(ValueError,
|
||||||
AttackInputData(loss_train=np.array([])).validate)
|
AttackInputData(loss_train=np.array([])).validate)
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
AttackInputData(entropy_train=np.array([])).validate)
|
||||||
self.assertRaises(ValueError,
|
self.assertRaises(ValueError,
|
||||||
AttackInputData(logits_test=np.array([])).validate)
|
AttackInputData(logits_test=np.array([])).validate)
|
||||||
self.assertRaises(ValueError,
|
self.assertRaises(ValueError,
|
||||||
AttackInputData(labels_test=np.array([])).validate)
|
AttackInputData(labels_test=np.array([])).validate)
|
||||||
self.assertRaises(ValueError,
|
self.assertRaises(ValueError,
|
||||||
AttackInputData(loss_test=np.array([])).validate)
|
AttackInputData(loss_test=np.array([])).validate)
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
AttackInputData(entropy_test=np.array([])).validate)
|
||||||
self.assertRaises(ValueError, AttackInputData().validate)
|
self.assertRaises(ValueError, AttackInputData().validate)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue