From d1c1746cdbbc10e4b380a0968057c34ea391dd83 Mon Sep 17 00:00:00 2001 From: amad-person Date: Fri, 6 Nov 2020 16:44:52 +0800 Subject: [PATCH] Add membership inference attack for seq2seq models --- .../membership_inference_attack.py | 47 ++++++++++++++++++- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py index f731958..f5a0f6a 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py @@ -19,12 +19,13 @@ This file belongs to the new API for membership inference attacks. This file will be renamed to membership_inference_attack.py after the old API is removed. """ -from typing import Iterable +from typing import Iterable, Union import numpy as np from sklearn import metrics from tensorflow_privacy.privacy.membership_inference_attack import models from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \ @@ -37,7 +38,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing impo from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_slice -def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec: +def _get_slice_spec(data: Union[AttackInputData, Seq2SeqAttackInputData]) -> SingleSliceSpec: if hasattr(data, 'slice_spec'): return data.slice_spec return SingleSliceSpec() @@ -170,6 
def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
                       privacy_report_metadata: PrivacyReportMetadata = None,
                       balance_attacker_training: bool = True) -> AttackResults:
  """Runs a membership inference attack on a seq2seq model.

  A logistic regression attacker is trained on the attacker training split
  produced by `models.create_seq2seq_attacker_data`, then evaluated on the
  corresponding attacker test split. The outcome is reported as a ROC curve.

  Args:
    attack_input: input data for running an attack. It is validated with
      `attack_input.validate()` before anything else happens.
    privacy_report_metadata: the metadata of the model under attack. It is
      attached unchanged to the returned `AttackResults`; may be None.
    balance_attacker_training: Whether the training and test sets for the
      membership inference attacker should have a balanced (roughly equal)
      number of samples from the training and test sets used to develop
      the model under attack.

  Returns:
    An `AttackResults` holding a single `SingleAttackResult` of type
    `AttackType.LOGISTIC_REGRESSION`, together with the supplied
    `privacy_report_metadata`.
  """
  attack_input.validate()
  attacker = models.LogisticRegressionAttacker()

  prepared_attacker_data = models.create_seq2seq_attacker_data(
      attack_input, balance=balance_attacker_training)

  attacker.train_model(prepared_attacker_data.features_train,
                       prepared_attacker_data.is_training_labels_train)

  # Run the attacker on (permuted) test examples.
  predictions_test = attacker.predict(prepared_attacker_data.features_test)

  # Generate ROC curves with predictions.
  fpr, tpr, thresholds = metrics.roc_curve(
      prepared_attacker_data.is_training_labels_test, predictions_test)
  roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)

  attack_results = [
      SingleAttackResult(
          slice_spec=_get_slice_spec(attack_input),
          attack_type=AttackType.LOGISTIC_REGRESSION,
          roc_curve=roc_curve)
  ]

  # Forward the caller-supplied metadata instead of silently dropping it, so
  # the result carries the same report information as `run_attacks` results.
  return AttackResults(
      single_attack_results=attack_results,
      privacy_report_metadata=privacy_report_metadata)