diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index 9f58740..3426739 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -279,12 +279,16 @@ class AttackInputData: """Returns size of the training set.""" if self.loss_train is not None: return self.loss_train.size + if self.entropy_train is not None: + return self.entropy_train.size return self.logits_or_probs_train.shape[0] def get_test_size(self): """Returns size of the test set.""" if self.loss_test is not None: return self.loss_test.size + if self.entropy_test is not None: + return self.entropy_test.size return self.logits_or_probs_test.shape[0] def validate(self): diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py index a56aa3a..40340e4 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py @@ -55,6 +55,12 @@ class RunAttacksTest(absltest.TestCase): self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK) + def test_run_attack_threshold_entropy_sets_attack_type(self): + result = mia._run_attack( + get_test_input(100, 100), AttackType.THRESHOLD_ENTROPY_ATTACK) + + self.assertEqual(result.attack_type, AttackType.THRESHOLD_ENTROPY_ATTACK) + def test_run_attack_threshold_calculates_correct_auc(self): result = mia._run_attack( AttackInputData( @@ -64,6 +70,15 @@ class RunAttacksTest(absltest.TestCase): np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2) + def test_run_attack_threshold_entropy_calculates_correct_auc(self): + result = mia._run_attack( + AttackInputData( + entropy_train=np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6]), + entropy_test=np.array([1.1, 1.2, 1.3, 0.4, 1.5, 1.6])), + AttackType.THRESHOLD_ENTROPY_ATTACK) + + np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2) + def test_run_attack_by_slice(self): result = mia.run_attacks( get_test_input(100, 100), SlicingSpec(by_class=True),