Add membership inference attack for seq2seq models

This commit is contained in:
amad-person 2020-11-06 16:44:52 +08:00
parent cd57910e5c
commit d1c1746cdb

View file

@ -19,12 +19,13 @@ This file belongs to the new API for membership inference attacks. This file
will be renamed to membership_inference_attack.py after the old API is removed. will be renamed to membership_inference_attack.py after the old API is removed.
""" """
from typing import Iterable from typing import Iterable, Union
import numpy as np import numpy as np
from sklearn import metrics from sklearn import metrics
from tensorflow_privacy.privacy.membership_inference_attack import models from tensorflow_privacy.privacy.membership_inference_attack import models
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \ from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
@ -37,7 +38,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing impo
from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_slice from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_slice
def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec: def _get_slice_spec(data: Union[AttackInputData, Seq2SeqAttackInputData]) -> SingleSliceSpec:
if hasattr(data, 'slice_spec'): if hasattr(data, 'slice_spec'):
return data.slice_spec return data.slice_spec
return SingleSliceSpec() return SingleSliceSpec()
@ -170,6 +171,48 @@ def run_attacks(attack_input: AttackInputData,
privacy_report_metadata=privacy_report_metadata) privacy_report_metadata=privacy_report_metadata)
def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
                       privacy_report_metadata: PrivacyReportMetadata = None,
                       balance_attacker_training: bool = True) -> AttackResults:
  """Runs a membership inference attack on a seq2seq model.

  A logistic regression attacker is trained on the training split produced by
  `models.create_seq2seq_attacker_data` and evaluated on the corresponding test
  split; the ROC curve of its predictions is reported as a single attack
  result.

  Args:
    attack_input: input data for running an attack. It is checked with its own
      `validate()` method before anything else happens.
    privacy_report_metadata: the metadata of the model under attack. It is
      forwarded unchanged (possibly `None`) to the returned `AttackResults`.
    balance_attacker_training: Whether the training and test sets for the
      membership inference attacker should have a balanced (roughly equal)
      number of samples from the training and test sets used to develop the
      model under attack.

  Returns:
    An `AttackResults` holding exactly one `SingleAttackResult` of type
    `AttackType.LOGISTIC_REGRESSION`, together with the supplied
    `privacy_report_metadata`.
  """
  attack_input.validate()

  # Build the attacker's own train/test data from the seq2seq attack input and
  # fit the attacker on the training part.
  attacker = models.LogisticRegressionAttacker()
  prepared_attacker_data = models.create_seq2seq_attacker_data(
      attack_input, balance=balance_attacker_training)
  attacker.train_model(prepared_attacker_data.features_train,
                       prepared_attacker_data.is_training_labels_train)

  # Score the test split of the prepared attacker data.
  predictions_test = attacker.predict(prepared_attacker_data.features_test)

  # Generate the ROC curve from the membership labels and the predictions.
  fpr, tpr, thresholds = metrics.roc_curve(
      prepared_attacker_data.is_training_labels_test, predictions_test)
  roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)

  attack_results = [
      SingleAttackResult(
          slice_spec=_get_slice_spec(attack_input),
          attack_type=AttackType.LOGISTIC_REGRESSION,
          roc_curve=roc_curve)
  ]

  # Forward the caller-supplied metadata, as run_attacks does, instead of
  # silently discarding the argument.
  return AttackResults(
      single_attack_results=attack_results,
      privacy_report_metadata=privacy_report_metadata)
def _compute_missing_privacy_report_metadata( def _compute_missing_privacy_report_metadata(
metadata: PrivacyReportMetadata, metadata: PrivacyReportMetadata,
attack_input: AttackInputData) -> PrivacyReportMetadata: attack_input: AttackInputData) -> PrivacyReportMetadata: