diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index fbafcd6..2f7c205 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -114,11 +114,13 @@ class AttackType(enum.Enum): RANDOM_FOREST = 'rf' K_NEAREST_NEIGHBORS = 'knn' THRESHOLD_ATTACK = 'threshold' + THRESHOLD_ENTROPY_ATTACK = 'threshold-entropy' @property def is_trained_attack(self): """Returns whether this type of attack requires training a model.""" - return self != AttackType.THRESHOLD_ATTACK + return (self != AttackType.THRESHOLD_ATTACK) and ( + self != AttackType.THRESHOLD_ENTROPY_ATTACK) def __str__(self): """Returns LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION.""" @@ -278,12 +280,16 @@ class AttackInputData: """Returns size of the training set.""" if self.loss_train is not None: return self.loss_train.size + if self.entropy_train is not None: + return self.entropy_train.size return self.logits_or_probs_train.shape[0] def get_test_size(self): """Returns size of the test set.""" if self.loss_test is not None: return self.loss_test.size + if self.entropy_test is not None: + return self.entropy_test.size return self.logits_or_probs_test.shape[0] def validate(self): diff --git a/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py index 8e5e4b0..fe108a6 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing.py @@ -41,12 +41,14 @@ def _slice_data_by_indices(data: AttackInputData, idx_train, result.probs_train = _slice_if_not_none(data.probs_train, idx_train) result.labels_train = _slice_if_not_none(data.labels_train, idx_train) result.loss_train = _slice_if_not_none(data.loss_train, idx_train) + result.entropy_train = _slice_if_not_none(data.entropy_train, idx_train) # Slice test data. result.logits_test = _slice_if_not_none(data.logits_test, idx_test) result.probs_test = _slice_if_not_none(data.probs_test, idx_test) result.labels_test = _slice_if_not_none(data.labels_test, idx_test) result.loss_test = _slice_if_not_none(data.loss_test, idx_test) + result.entropy_test = _slice_if_not_none(data.entropy_test, idx_test) return result diff --git a/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing_test.py b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing_test.py index 75b8a3f..48072e1 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/dataset_slicing_test.py @@ -114,6 +114,8 @@ class GetSliceTest(absltest.TestCase): labels_test = np.array([1, 2, 0, 2]) loss_train = np.array([2, 0.25, 4, 3]) loss_test = np.array([0.5, 3.5, 7, 4.5]) + entropy_train = np.array([0.4, 8, 0.6, 10]) + entropy_test = np.array([15, 10.5, 4.5, 0.3]) self.input_data = AttackInputData( logits_train=logits_train, @@ -123,7 +125,9 @@ class GetSliceTest(absltest.TestCase): labels_train=labels_train, labels_test=labels_test, loss_train=loss_train, - loss_test=loss_test) + loss_test=loss_test, + entropy_train=entropy_train, + entropy_test=entropy_test) def test_slice_entire_dataset(self): entire_dataset_slice = SingleSliceSpec() @@ -159,6 +163,12 @@ class GetSliceTest(absltest.TestCase): self.assertTrue((output.loss_train == [2, 4]).all()) self.assertTrue((output.loss_test == [0.5]).all()) + # Check entropy + self.assertLen(output.entropy_train, 2) + self.assertLen(output.entropy_test, 1) + self.assertTrue((output.entropy_train == [0.4, 0.6]).all()) + self.assertTrue((output.entropy_test == [15]).all()) + def test_slice_by_percentile(self): percentile_slice = SingleSliceSpec(SlicingFeature.PERCENTILE, (0, 50)) output = get_slice(self.input_data, percentile_slice) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py index 8f877a7..f731958 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py @@ -97,6 +97,21 @@ def _run_threshold_attack(attack_input: AttackInputData): roc_curve=roc_curve) +def _run_threshold_entropy_attack(attack_input: AttackInputData): + fpr, tpr, thresholds = metrics.roc_curve( + np.concatenate((np.zeros(attack_input.get_train_size()), + np.ones(attack_input.get_test_size()))), + np.concatenate( + (attack_input.get_entropy_train(), attack_input.get_entropy_test()))) + + roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds) + + return SingleAttackResult( + slice_spec=_get_slice_spec(attack_input), + attack_type=AttackType.THRESHOLD_ENTROPY_ATTACK, + roc_curve=roc_curve) + + def _run_attack(attack_input: AttackInputData, attack_type: AttackType, balance_attacker_training: bool = True): @@ -104,7 +119,8 @@ def _run_attack(attack_input: AttackInputData, if attack_type.is_trained_attack: return _run_trained_attack(attack_input, attack_type, balance_attacker_training) - + if attack_type == AttackType.THRESHOLD_ENTROPY_ATTACK: + return _run_threshold_entropy_attack(attack_input) return _run_threshold_attack(attack_input) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py index a56aa3a..d6b9867 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py @@ -55,6 +55,12 @@ class RunAttacksTest(absltest.TestCase): self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK) + def test_run_attack_threshold_entropy_sets_attack_type(self): + result = mia._run_attack( + get_test_input(100, 100), AttackType.THRESHOLD_ENTROPY_ATTACK) + + self.assertEqual(result.attack_type, AttackType.THRESHOLD_ENTROPY_ATTACK) + def test_run_attack_threshold_calculates_correct_auc(self): result = mia._run_attack( AttackInputData( @@ -64,6 +70,15 @@ class RunAttacksTest(absltest.TestCase): np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2) + def test_run_attack_threshold_entropy_calculates_correct_auc(self): + result = mia._run_attack( + AttackInputData( + entropy_train=np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6]), + entropy_test=np.array([1.1, 1.2, 1.3, 0.4, 1.5, 1.6])), + AttackType.THRESHOLD_ENTROPY_ATTACK) + + np.testing.assert_almost_equal(result.roc_curve.get_auc(), 0.83, decimal=2) + def test_run_attack_by_slice(self): result = mia.run_attacks( get_test_input(100, 100), SlicingSpec(by_class=True),