forked from 626_privacy/tensorflow_privacy
add test case for entropy attack
This commit is contained in:
parent
893b615d72
commit
6e929da966
2 changed files with 19 additions and 0 deletions
|
@ -279,12 +279,16 @@ class AttackInputData:
|
||||||
"""Returns size of the training set."""
|
"""Returns size of the training set."""
|
||||||
if self.loss_train is not None:
|
if self.loss_train is not None:
|
||||||
return self.loss_train.size
|
return self.loss_train.size
|
||||||
|
if self.entropy_train is not None:
|
||||||
|
return self.entropy_train.size
|
||||||
return self.logits_or_probs_train.shape[0]
|
return self.logits_or_probs_train.shape[0]
|
||||||
|
|
||||||
def get_test_size(self):
|
def get_test_size(self):
|
||||||
"""Returns size of the test set."""
|
"""Returns size of the test set."""
|
||||||
if self.loss_test is not None:
|
if self.loss_test is not None:
|
||||||
return self.loss_test.size
|
return self.loss_test.size
|
||||||
|
if self.entropy_test is not None:
|
||||||
|
return self.entropy_test.size
|
||||||
return self.logits_or_probs_test.shape[0]
|
return self.logits_or_probs_test.shape[0]
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
|
|
|
@ -55,6 +55,12 @@ class RunAttacksTest(absltest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK)
|
self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK)
|
||||||
|
|
||||||
|
def test_run_attack_threshold_entropy_sets_attack_type(self):
|
||||||
|
result = mia._run_attack(
|
||||||
|
get_test_input(100, 100), AttackType.THRESHOLD_ENTROPY_ATTACK)
|
||||||
|
|
||||||
|
self.assertEqual(result.attack_type, AttackType.THRESHOLD_ENTROPY_ATTACK)
|
||||||
|
|
||||||
def test_run_attack_threshold_calculates_correct_auc(self):
|
def test_run_attack_threshold_calculates_correct_auc(self):
|
||||||
result = mia._run_attack(
|
result = mia._run_attack(
|
||||||
AttackInputData(
|
AttackInputData(
|
||||||
|
@ -64,6 +70,15 @@ class RunAttacksTest(absltest.TestCase):
|
||||||
|
|
||||||
np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2)
|
np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2)
|
||||||
|
|
||||||
|
def test_run_attack_threshold_entropy_calculates_correct_auc(self):
|
||||||
|
result = mia._run_attack(
|
||||||
|
AttackInputData(
|
||||||
|
entropy_train=np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6]),
|
||||||
|
entropy_test=np.array([1.1, 1.2, 1.3, 0.4, 1.5, 1.6])),
|
||||||
|
AttackType.THRESHOLD_ENTROPY_ATTACK)
|
||||||
|
|
||||||
|
np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2)
|
||||||
|
|
||||||
def test_run_attack_by_slice(self):
|
def test_run_attack_by_slice(self):
|
||||||
result = mia.run_attacks(
|
result = mia.run_attacks(
|
||||||
get_test_input(100, 100), SlicingSpec(by_class=True),
|
get_test_input(100, 100), SlicingSpec(by_class=True),
|
||||||
|
|
Loading…
Reference in a new issue