forked from 626_privacy/tensorflow_privacy
Add membership inference attack for seq2seq models
This commit is contained in:
parent
cd57910e5c
commit
d1c1746cdb
1 changed files with 45 additions and 2 deletions
|
@ -19,12 +19,13 @@ This file belongs to the new API for membership inference attacks. This file
|
||||||
will be renamed to membership_inference_attack.py after the old API is removed.
|
will be renamed to membership_inference_attack.py after the old API is removed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Iterable
|
from typing import Iterable, Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack import models
|
from tensorflow_privacy.privacy.membership_inference_attack import models
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
||||||
|
@ -37,7 +38,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing impo
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_slice
|
from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_slice
|
||||||
|
|
||||||
|
|
||||||
def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
|
def _get_slice_spec(data: Union[AttackInputData, Seq2SeqAttackInputData]) -> SingleSliceSpec:
|
||||||
if hasattr(data, 'slice_spec'):
|
if hasattr(data, 'slice_spec'):
|
||||||
return data.slice_spec
|
return data.slice_spec
|
||||||
return SingleSliceSpec()
|
return SingleSliceSpec()
|
||||||
|
@ -170,6 +171,48 @@ def run_attacks(attack_input: AttackInputData,
|
||||||
privacy_report_metadata=privacy_report_metadata)
|
privacy_report_metadata=privacy_report_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
||||||
|
privacy_report_metadata: PrivacyReportMetadata = None,
|
||||||
|
balance_attacker_training: bool = True) -> AttackResults:
|
||||||
|
"""Runs membership inference attacks on a seq2seq model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attack_input: input data for running an attack
|
||||||
|
privacy_report_metadata: the metadata of the model under attack.
|
||||||
|
balance_attacker_training: Whether the training and test sets for the
|
||||||
|
membership inference attacker should have a balanced (roughly equal)
|
||||||
|
number of samples from the training and test sets used to develop
|
||||||
|
the model under attack.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the attack result.
|
||||||
|
"""
|
||||||
|
attack_input.validate()
|
||||||
|
attacker = models.LogisticRegressionAttacker()
|
||||||
|
|
||||||
|
prepared_attacker_data = models.create_seq2seq_attacker_data(
|
||||||
|
attack_input, balance=balance_attacker_training)
|
||||||
|
|
||||||
|
attacker.train_model(prepared_attacker_data.features_train,
|
||||||
|
prepared_attacker_data.is_training_labels_train)
|
||||||
|
|
||||||
|
# Run the attacker on (permuted) test examples.
|
||||||
|
predictions_test = attacker.predict(prepared_attacker_data.features_test)
|
||||||
|
|
||||||
|
# Generate ROC curves with predictions.
|
||||||
|
fpr, tpr, thresholds = metrics.roc_curve(
|
||||||
|
prepared_attacker_data.is_training_labels_test, predictions_test)
|
||||||
|
|
||||||
|
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
|
||||||
|
|
||||||
|
attack_results = [SingleAttackResult(
|
||||||
|
slice_spec=_get_slice_spec(attack_input),
|
||||||
|
attack_type=AttackType.LOGISTIC_REGRESSION,
|
||||||
|
roc_curve=roc_curve)]
|
||||||
|
|
||||||
|
return AttackResults(single_attack_results=attack_results)
|
||||||
|
|
||||||
|
|
||||||
def _compute_missing_privacy_report_metadata(
|
def _compute_missing_privacy_report_metadata(
|
||||||
metadata: PrivacyReportMetadata,
|
metadata: PrivacyReportMetadata,
|
||||||
attack_input: AttackInputData) -> PrivacyReportMetadata:
|
attack_input: AttackInputData) -> PrivacyReportMetadata:
|
||||||
|
|
Loading…
Reference in a new issue