add test case for entropy attack

This commit is contained in:
Liwei Song 2020-10-23 09:30:09 -04:00
parent 893b615d72
commit 6e929da966
2 changed files with 19 additions and 0 deletions

View file

@ -279,12 +279,16 @@ class AttackInputData:
"""Returns size of the training set.""" """Returns size of the training set."""
if self.loss_train is not None: if self.loss_train is not None:
return self.loss_train.size return self.loss_train.size
if self.entropy_train is not None:
return self.entropy_train.size
return self.logits_or_probs_train.shape[0] return self.logits_or_probs_train.shape[0]
def get_test_size(self): def get_test_size(self):
"""Returns size of the test set.""" """Returns size of the test set."""
if self.loss_test is not None: if self.loss_test is not None:
return self.loss_test.size return self.loss_test.size
if self.entropy_test is not None:
return self.entropy_test.size
return self.logits_or_probs_test.shape[0] return self.logits_or_probs_test.shape[0]
def validate(self): def validate(self):

View file

@ -55,6 +55,12 @@ class RunAttacksTest(absltest.TestCase):
self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK) self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK)
def test_run_attack_threshold_entropy_sets_attack_type(self):
result = mia._run_attack(
get_test_input(100, 100), AttackType.THRESHOLD_ENTROPY_ATTACK)
self.assertEqual(result.attack_type, AttackType.THRESHOLD_ENTROPY_ATTACK)
def test_run_attack_threshold_calculates_correct_auc(self): def test_run_attack_threshold_calculates_correct_auc(self):
result = mia._run_attack( result = mia._run_attack(
AttackInputData( AttackInputData(
@ -64,6 +70,15 @@ class RunAttacksTest(absltest.TestCase):
np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2) np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2)
def test_run_attack_threshold_entropy_calculates_correct_auc(self):
result = mia._run_attack(
AttackInputData(
entropy_train=np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6]),
entropy_test=np.array([1.1, 1.2, 1.3, 0.4, 1.5, 1.6])),
AttackType.THRESHOLD_ENTROPY_ATTACK)
np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2)
def test_run_attack_by_slice(self): def test_run_attack_by_slice(self):
result = mia.run_attacks( result = mia.run_attacks(
get_test_input(100, 100), SlicingSpec(by_class=True), get_test_input(100, 100), SlicingSpec(by_class=True),