From 3f40b8c465e01fd88c2ab0953d0669fcbbf6b863 Mon Sep 17 00:00:00 2001 From: Liwei Song Date: Mon, 14 Dec 2020 14:49:30 -0500 Subject: [PATCH] update attack code --- .../membership_inference_attack.py | 49 ------------------- 1 file changed, 49 deletions(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py index ccecbd5..28ddc75 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py @@ -30,7 +30,6 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \ PrivacyReportMetadata from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve -from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec @@ -175,54 +174,6 @@ def run_attacks(attack_input: AttackInputData, privacy_report_metadata=privacy_report_metadata) -def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData, - unused_report_metadata: PrivacyReportMetadata = None, - balance_attacker_training: bool = True) -> AttackResults: - """Runs membership inference attacks on a seq2seq model. - - Args: - attack_input: input data for running an attack - unused_report_metadata: the metadata of the model under attack. - balance_attacker_training: Whether the training and test sets for the - membership inference attacker should have a balanced (roughly equal) - number of samples from the training and test sets used to develop the - model under attack. - - Returns: - the attack result. - """ - attack_input.validate() - - # The attacker uses the average rank (a single number) of a seq2seq dataset - # record to determine membership. So only Logistic Regression is supported, - # as it makes the most sense for single-number features. - attacker = models.LogisticRegressionAttacker() - - prepared_attacker_data = models.create_seq2seq_attacker_data( - attack_input, balance=balance_attacker_training) - - attacker.train_model(prepared_attacker_data.features_train, - prepared_attacker_data.is_training_labels_train) - - # Run the attacker on (permuted) test examples. - predictions_test = attacker.predict(prepared_attacker_data.features_test) - - # Generate ROC curves with predictions. - fpr, tpr, thresholds = metrics.roc_curve( - prepared_attacker_data.is_training_labels_test, predictions_test) - - roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds) - - attack_results = [ - SingleAttackResult( - slice_spec=SingleSliceSpec(), - attack_type=AttackType.LOGISTIC_REGRESSION, - roc_curve=roc_curve) - ] - - return AttackResults(single_attack_results=attack_results) - - def _compute_privacy_risk_score(attack_input: AttackInputData, num_bins: int = 15) -> SingleRiskScoreResult: """compute each individual point's likelihood of being a member (https://arxiv.org/abs/2003.10595)