add entropy tests

This commit is contained in:
Liwei Song 2020-09-02 11:37:12 -04:00
parent 9b2e6a55b6
commit 0e1c1eeef3

View file

@ -21,6 +21,7 @@ from absl.testing import parameterized
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import _log_value
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
@ -65,6 +66,33 @@ class AttackInputDataTest(absltest.TestCase):
np.testing.assert_equal(attack_input.get_loss_test().tolist(), np.testing.assert_equal(attack_input.get_loss_test().tolist(),
[1.0, 4.0, 6.0]) [1.0, 4.0, 6.0])
def test_get_entropy(self):
attack_input = AttackInputData(
logits_train=np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
logits_test=np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]),
labels_train=np.array([0, 2]),
labels_test=np.array([0, 2]))
np.testing.assert_equal(attack_input.get_entropy_train().tolist(), [0, 0])
np.testing.assert_equal(attack_input.get_entropy_test().tolist(), [2*_log_value(0), 0])
attack_input = AttackInputData(
logits_train=np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
logits_test=np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))
np.testing.assert_equal(attack_input.get_entropy_train().tolist(), [0, 0])
np.testing.assert_equal(attack_input.get_entropy_test().tolist(), [0, 0])
def test_get_entropy_explicitly_provided(self):
attack_input = AttackInputData(
entropy_train=np.array([0.0, 2.0, 1.0]),
entropy_test=np.array([0.5, 3.0, 5.0]))
np.testing.assert_equal(attack_input.get_entropy_train().tolist(),
[0.0, 2.0, 1.0])
np.testing.assert_equal(attack_input.get_entropy_test().tolist(),
[0.5, 3.0, 5.0])
def test_validator(self): def test_validator(self):
self.assertRaises(ValueError, self.assertRaises(ValueError,
AttackInputData(logits_train=np.array([])).validate) AttackInputData(logits_train=np.array([])).validate)
@ -72,12 +100,16 @@ class AttackInputDataTest(absltest.TestCase):
AttackInputData(labels_train=np.array([])).validate) AttackInputData(labels_train=np.array([])).validate)
self.assertRaises(ValueError, self.assertRaises(ValueError,
AttackInputData(loss_train=np.array([])).validate) AttackInputData(loss_train=np.array([])).validate)
self.assertRaises(ValueError,
AttackInputData(entropy_train=np.array([])).validate)
self.assertRaises(ValueError, self.assertRaises(ValueError,
AttackInputData(logits_test=np.array([])).validate) AttackInputData(logits_test=np.array([])).validate)
self.assertRaises(ValueError, self.assertRaises(ValueError,
AttackInputData(labels_test=np.array([])).validate) AttackInputData(labels_test=np.array([])).validate)
self.assertRaises(ValueError, self.assertRaises(ValueError,
AttackInputData(loss_test=np.array([])).validate) AttackInputData(loss_test=np.array([])).validate)
self.assertRaises(ValueError,
AttackInputData(entropy_test=np.array([])).validate)
self.assertRaises(ValueError, AttackInputData().validate) self.assertRaises(ValueError, AttackInputData().validate)