add entropy tests
This commit is contained in:
parent
9b2e6a55b6
commit
0e1c1eeef3
1 changed files with 32 additions and 0 deletions
|
@ -21,6 +21,7 @@ from absl.testing import parameterized
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import _log_value
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||
|
@ -65,6 +66,33 @@ class AttackInputDataTest(absltest.TestCase):
|
|||
np.testing.assert_equal(attack_input.get_loss_test().tolist(),
|
||||
[1.0, 4.0, 6.0])
|
||||
|
||||
def test_get_entropy(self):
|
||||
attack_input = AttackInputData(
|
||||
logits_train=np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
|
||||
logits_test=np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
|
||||
labels_train=np.array([0, 2]),
|
||||
labels_test=np.array([0, 2]))
|
||||
|
||||
np.testing.assert_equal(attack_input.get_entropy_train().tolist(), [0, 0])
|
||||
np.testing.assert_equal(attack_input.get_entropy_test().tolist(), [2*_log_value(0), 0])
|
||||
|
||||
attack_input = AttackInputData(
|
||||
logits_train=np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
|
||||
logits_test=np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))
|
||||
|
||||
np.testing.assert_equal(attack_input.get_entropy_train().tolist(), [0, 0])
|
||||
np.testing.assert_equal(attack_input.get_entropy_test().tolist(), [0, 0])
|
||||
|
||||
def test_get_entropy_explicitly_provided(self):
|
||||
attack_input = AttackInputData(
|
||||
entropy_train=np.array([0.0, 2.0, 1.0]),
|
||||
entropy_test=np.array([0.5, 3.0, 5.0]))
|
||||
|
||||
np.testing.assert_equal(attack_input.get_entropy_train().tolist(),
|
||||
[0.0, 2.0, 1.0])
|
||||
np.testing.assert_equal(attack_input.get_entropy_test().tolist(),
|
||||
[0.5, 3.0, 5.0])
|
||||
|
||||
def test_validator(self):
|
||||
self.assertRaises(ValueError,
|
||||
AttackInputData(logits_train=np.array([])).validate)
|
||||
|
@ -72,12 +100,16 @@ class AttackInputDataTest(absltest.TestCase):
|
|||
AttackInputData(labels_train=np.array([])).validate)
|
||||
self.assertRaises(ValueError,
|
||||
AttackInputData(loss_train=np.array([])).validate)
|
||||
self.assertRaises(ValueError,
|
||||
AttackInputData(entropy_train=np.array([])).validate)
|
||||
self.assertRaises(ValueError,
|
||||
AttackInputData(logits_test=np.array([])).validate)
|
||||
self.assertRaises(ValueError,
|
||||
AttackInputData(labels_test=np.array([])).validate)
|
||||
self.assertRaises(ValueError,
|
||||
AttackInputData(loss_test=np.array([])).validate)
|
||||
self.assertRaises(ValueError,
|
||||
AttackInputData(entropy_test=np.array([])).validate)
|
||||
self.assertRaises(ValueError, AttackInputData().validate)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue