diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new.py index 444b8bf..55d5656 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new.py @@ -43,7 +43,7 @@ def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec: return SingleSliceSpec() -def run_trained_attack(attack_input: AttackInputData, attack_type: AttackType): +def _run_trained_attack(attack_input: AttackInputData, attack_type: AttackType): """Classification attack done by ML models.""" attacker = None @@ -79,7 +79,7 @@ def run_trained_attack(attack_input: AttackInputData, attack_type: AttackType): roc_curve=roc_curve) -def run_threshold_attack(attack_input: AttackInputData): +def _run_threshold_attack(attack_input: AttackInputData): fpr, tpr, thresholds = metrics.roc_curve( np.concatenate((np.zeros(attack_input.get_train_size()), np.ones(attack_input.get_test_size()))), @@ -94,12 +94,12 @@ def run_threshold_attack(attack_input: AttackInputData): roc_curve=roc_curve) -def run_attack(attack_input: AttackInputData, attack_type: AttackType): +def _run_attack(attack_input: AttackInputData, attack_type: AttackType): attack_input.validate() if attack_type.is_trained_attack: - return run_trained_attack(attack_input, attack_type) + return _run_trained_attack(attack_input, attack_type) - return run_threshold_attack(attack_input) + return _run_threshold_attack(attack_input) def run_attacks( @@ -107,7 +107,20 @@ def run_attacks( slicing_spec: SlicingSpec = None, attack_types: Iterable[AttackType] = (AttackType.THRESHOLD_ATTACK,), privacy_report_metadata: PrivacyReportMetadata = None) -> AttackResults: - """Run all attacks.""" + """Runs membership inference attacks on a classification model. + + It runs attacks specified by attack_types on each attack_input slice which is + specified by slicing_spec. + + Args: + attack_input: input data for running an attack + slicing_spec: specifies attack_input slices to run attack on + attack_types: attacks to run + privacy_report_metadata: the metadata of the model under attack. + + Returns: + the attack result. + """ attack_input.validate() attack_results = [] @@ -118,7 +131,7 @@ def run_attacks( for single_slice_spec in input_slice_specs: attack_input_slice = get_slice(attack_input, single_slice_spec) for attack_type in attack_types: - attack_results.append(run_attack(attack_input_slice, attack_type)) + attack_results.append(_run_attack(attack_input_slice, attack_type)) privacy_report_metadata = _compute_missing_privacy_report_metadata( privacy_report_metadata, attack_input) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py index 1803fa0..42881f4 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_new_test.py @@ -43,19 +43,19 @@ class RunAttacksTest(absltest.TestCase): self.assertLen(result.single_attack_results, 2) def test_run_attack_trained_sets_attack_type(self): - result = mia.run_attack( + result = mia._run_attack( get_test_input(100, 100), AttackType.LOGISTIC_REGRESSION) self.assertEqual(result.attack_type, AttackType.LOGISTIC_REGRESSION) def test_run_attack_threshold_sets_attack_type(self): - result = mia.run_attack( + result = mia._run_attack( get_test_input(100, 100), AttackType.THRESHOLD_ATTACK) self.assertEqual(result.attack_type, AttackType.THRESHOLD_ATTACK) def test_run_attack_threshold_calculates_correct_auc(self): - result = mia.run_attack( + result = mia._run_attack( AttackInputData( loss_train=np.array([0.1, 0.2, 1.3, 0.4, 0.5, 0.6]), loss_test=np.array([1.1, 1.2, 1.3, 0.4, 1.5, 1.6])),