diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/seq2seq_mia.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/seq2seq_mia.py
index d848cdd..6f7fc82 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/seq2seq_mia.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/seq2seq_mia.py
@@ -23,19 +23,12 @@ from typing import Iterator, List, Optional, Tuple
 import numpy as np
 from scipy import stats
-from sklearn import metrics
-from sklearn import model_selection
 import tensorflow as tf

-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import models
+from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
+from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import RocCurve
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleAttackResult
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleSliceSpec
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.models import _sample_multidimensional_array
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.models import AttackerData


 def _is_iterator(obj, obj_name):
@@ -266,66 +259,6 @@ def _get_batch_accuracy_metrics(
   return batch_correct_preds, batch_total_preds


-def create_seq2seq_attacker_data(
-    attack_input_data: Seq2SeqAttackInputData,
-    test_fraction: float = 0.25,
-    balance: bool = True,
-    privacy_report_metadata: PrivacyReportMetadata = PrivacyReportMetadata()
-) -> AttackerData:
-  """Prepares Seq2SeqAttackInputData to train ML attackers.
-
-  Uses logits and losses to generate ranks and performs a random train-test
-  split.
-
-  Also computes metadata (loss, accuracy) for the model under attack
-  and populates respective fields of PrivacyReportMetadata.
-
-  Args:
-    attack_input_data: Original Seq2SeqAttackInputData
-    test_fraction: Fraction of the dataset to include in the test split.
-    balance: Whether the training and test sets for the membership inference
-      attacker should have a balanced (roughly equal) number of samples from the
-      training and test sets used to develop the model under attack.
-    privacy_report_metadata: the metadata of the model under attack.
-
-  Returns:
-    AttackerData.
-  """
-  attack_input_train, loss_train, accuracy_train = _get_attack_features_and_metadata(
-      attack_input_data.logits_train, attack_input_data.labels_train)
-  attack_input_test, loss_test, accuracy_test = _get_attack_features_and_metadata(
-      attack_input_data.logits_test, attack_input_data.labels_test)
-
-  if balance:
-    min_size = min(len(attack_input_train), len(attack_input_test))
-    attack_input_train = _sample_multidimensional_array(attack_input_train,
-                                                        min_size)
-    attack_input_test = _sample_multidimensional_array(attack_input_test,
-                                                       min_size)
-
-  features_all = np.concatenate((attack_input_train, attack_input_test))
-  ntrain, ntest = attack_input_train.shape[0], attack_input_test.shape[0]
-
-  # Reshape for classifying one-dimensional features
-  features_all = features_all.reshape(-1, 1)
-
-  labels_all = np.concatenate(((np.zeros(ntrain)), (np.ones(ntest))))
-
-  # Perform a train-test split
-  features_train, features_test, is_training_labels_train, is_training_labels_test = model_selection.train_test_split(
-      features_all, labels_all, test_size=test_fraction, stratify=labels_all)
-
-  # Populate accuracy, loss fields in privacy report metadata
-  privacy_report_metadata.loss_train = loss_train
-  privacy_report_metadata.loss_test = loss_test
-  privacy_report_metadata.accuracy_train = accuracy_train
-  privacy_report_metadata.accuracy_test = accuracy_test
-
-  return AttackerData(features_train, is_training_labels_train, features_test,
-                      is_training_labels_test,
-                      DataSize(ntrain=ntrain, ntest=ntest))
-
-
 def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
                        privacy_report_metadata: PrivacyReportMetadata = None,
                        balance_attacker_training: bool = True) -> AttackResults:
@@ -343,39 +276,23 @@
     the attack result.
   """
   attack_input.validate()
+  attack_input_train, loss_train, accuracy_train = _get_attack_features_and_metadata(
+      attack_input.logits_train, attack_input.labels_train)
+  attack_input_test, loss_test, accuracy_test = _get_attack_features_and_metadata(
+      attack_input.logits_test, attack_input.labels_test)

-  # The attacker uses the average rank (a single number) of a seq2seq dataset
-  # record to determine membership. So only Logistic Regression is supported,
-  # as it makes the most sense for single-number features.
-  attacker = models.LogisticRegressionAttacker()
-
-  # Create attacker data and populate fields of privacy_report_metadata
   privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
-  prepared_attacker_data = create_seq2seq_attacker_data(
-      attack_input_data=attack_input,
-      balance=balance_attacker_training,
-      privacy_report_metadata=privacy_report_metadata)
+  privacy_report_metadata.loss_train = loss_train
+  privacy_report_metadata.loss_test = loss_test
+  privacy_report_metadata.accuracy_train = accuracy_train
+  privacy_report_metadata.accuracy_test = accuracy_test

-  attacker.train_model(prepared_attacker_data.features_train,
-                       prepared_attacker_data.is_training_labels_train)
-
-  # Run the attacker on (permuted) test examples.
-  predictions_test = attacker.predict(prepared_attacker_data.features_test)
-
-  # Generate ROC curves with predictions.
-  fpr, tpr, thresholds = metrics.roc_curve(
-      prepared_attacker_data.is_training_labels_test, predictions_test)
-
-  roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
-
-  attack_results = [
-      SingleAttackResult(
-          slice_spec=SingleSliceSpec(),
-          attack_type=AttackType.LOGISTIC_REGRESSION,
-          roc_curve=roc_curve,
-          data_size=prepared_attacker_data.data_size)
-  ]
-
-  return AttackResults(
-      single_attack_results=attack_results,
-      privacy_report_metadata=privacy_report_metadata)
+  # `attack_input_train` and `attack_input_test` contain the rank of the
+  # ground-truth label in the logits, so a smaller value means an example is
+  # more likely a training example.
+  return mia.run_attacks(
+      AttackInputData(
+          loss_train=attack_input_train, loss_test=attack_input_test),
+      attack_types=(AttackType.THRESHOLD_ATTACK,),
+      privacy_report_metadata=privacy_report_metadata,
+      balance_attacker_training=balance_attacker_training)
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/seq2seq_mia_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/seq2seq_mia_test.py
index 03c1fb2..af8edc1 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/seq2seq_mia_test.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/seq2seq_mia_test.py
@@ -16,8 +16,6 @@ from absl.testing import absltest
 import numpy as np

 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
-from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import create_seq2seq_attacker_data
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import run_seq2seq_attack
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import Seq2SeqAttackInputData

@@ -91,154 +89,6 @@ class Seq2SeqAttackInputDataTest(absltest.TestCase):
             test_size=0).validate)


-class Seq2SeqTrainedAttackerTest(absltest.TestCase):
-
-  def test_create_seq2seq_attacker_data_logits_and_labels(self):
-    attack_input = Seq2SeqAttackInputData(
-        logits_train=iter([
-            np.array([
-                np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
-                np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
-            ],
-                     dtype=object),
-            np.array(
-                [np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
-                dtype=object),
-            np.array([
-                np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
-                np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
-            ],
-                     dtype=object)
-        ]),
-        logits_test=iter([
-            np.array([
-                np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
-            ],
-                     dtype=object),
-            np.array([
-                np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
-                np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
-            ],
-                     dtype=object)
-        ]),
-        labels_train=iter([
-            np.array([
-                np.array([2, 0], dtype=np.float32),
-                np.array([1], dtype=np.float32)
-            ],
-                     dtype=object),
-            np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
-            np.array([
-                np.array([0, 1], dtype=np.float32),
-                np.array([1, 2], dtype=np.float32)
-            ],
-                     dtype=object)
-        ]),
-        labels_test=iter([
-            np.array([np.array([2, 1], dtype=np.float32)]),
-            np.array([
-                np.array([2, 0], dtype=np.float32),
-                np.array([1], dtype=np.float32)
-            ],
-                     dtype=object)
-        ]),
-        vocab_size=3,
-        train_size=3,
-        test_size=2)
-    privacy_report_metadata = PrivacyReportMetadata()
-    attacker_data = create_seq2seq_attacker_data(
-        attack_input_data=attack_input,
-        test_fraction=0.25,
-        balance=False,
-        privacy_report_metadata=privacy_report_metadata)
-    self.assertLen(attacker_data.features_train, 3)
-    self.assertLen(attacker_data.features_test, 2)
-
-    for _, feature in enumerate(attacker_data.features_train):
-      self.assertLen(feature, 1)  # each feature has one average rank
-
-    # Tests that fields of PrivacyReportMetadata are populated.
-    self.assertIsNotNone(privacy_report_metadata.loss_train)
-    self.assertIsNotNone(privacy_report_metadata.loss_test)
-    self.assertIsNotNone(privacy_report_metadata.accuracy_train)
-    self.assertIsNotNone(privacy_report_metadata.accuracy_test)
-
-  def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
-    attack_input = Seq2SeqAttackInputData(
-        logits_train=iter([
-            np.array([
-                np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
-                np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
-            ],
-                     dtype=object),
-            np.array(
-                [np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
-                dtype=object),
-            np.array([
-                np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
-                np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
-            ],
-                     dtype=object)
-        ]),
-        logits_test=iter([
-            np.array([
-                np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
-            ],
-                     dtype=object),
-            np.array([
-                np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
-                np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
-            ],
-                     dtype=object),
-            np.array([
-                np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
-            ],
-                     dtype=object)
-        ]),
-        labels_train=iter([
-            np.array([
-                np.array([2, 0], dtype=np.float32),
-                np.array([1], dtype=np.float32)
-            ],
-                     dtype=object),
-            np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
-            np.array([
-                np.array([0, 1], dtype=np.float32),
-                np.array([1, 2], dtype=np.float32)
-            ],
-                     dtype=object)
-        ]),
-        labels_test=iter([
-            np.array([np.array([2, 1], dtype=np.float32)]),
-            np.array([
-                np.array([2, 0], dtype=np.float32),
-                np.array([1], dtype=np.float32)
-            ],
-                     dtype=object),
-            np.array([np.array([2, 1], dtype=np.float32)])
-        ]),
-        vocab_size=3,
-        train_size=3,
-        test_size=3)
-    privacy_report_metadata = PrivacyReportMetadata()
-    attacker_data = create_seq2seq_attacker_data(
-        attack_input_data=attack_input,
-        test_fraction=0.33,
-        balance=True,
-        privacy_report_metadata=privacy_report_metadata)
-    self.assertLen(attacker_data.features_train, 4)
-    self.assertLen(attacker_data.features_test, 2)
-
-    for _, feature in enumerate(attacker_data.features_train):
-      self.assertLen(feature, 1)  # each feature has one average rank
-
-    # Tests that fields of PrivacyReportMetadata are populated.
-    self.assertIsNotNone(privacy_report_metadata.loss_train)
-    self.assertIsNotNone(privacy_report_metadata.loss_test)
-    self.assertIsNotNone(privacy_report_metadata.accuracy_train)
-    self.assertIsNotNone(privacy_report_metadata.accuracy_test)
-
-
 def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
                                  vocab_size):
   num_tokens_in_sequence = np.random.choice(max_tokens_in_sequence,
@@ -323,7 +173,7 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
             max_tokens_in_sequence=5,
             vocab_size=2))
     seq2seq_result = list(result.single_attack_results)[0]
-    self.assertEqual(seq2seq_result.attack_type, AttackType.LOGISTIC_REGRESSION)
+    self.assertEqual(seq2seq_result.attack_type, AttackType.THRESHOLD_ATTACK)

   def test_run_seq2seq_attack_calculates_correct_auc(self):
     result = run_seq2seq_attack(
@@ -337,7 +187,7 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
             max_tokens_in_sequence=5,
             vocab_size=2),
         balance_attacker_training=False)
     seq2seq_result = list(result.single_attack_results)[0]
     np.testing.assert_almost_equal(
-        seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
+        seq2seq_result.roc_curve.get_auc(), 0.59, decimal=2)

   def test_run_seq2seq_attack_calculates_correct_metadata(self):
     attack_input = Seq2SeqAttackInputData(
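
Note on the patch above: the trained logistic-regression attacker is replaced by a threshold attack over the per-record average rank of the ground-truth token, fed to `mia.run_attacks` through the `loss_train`/`loss_test` fields of `AttackInputData`. The sketch below is a minimal, self-contained illustration of that idea, not part of the patch. The toy logits/labels and the `rank_of_true_label` helper are invented for the example and stand in for `_get_attack_features_and_metadata` (which produces an average rank per record in the real code); `AttackInputData`, `AttackType.THRESHOLD_ATTACK`, `mia.run_attacks` and `roc_curve.get_auc()` are the same APIs the patch itself uses.

import numpy as np

from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType


def rank_of_true_label(logits, labels):
  """Illustrative stand-in: 1-based rank of the true label's logit per example."""
  # Double argsort turns each row of logits into 0-based descending ranks.
  ranks = (-logits).argsort(axis=1).argsort(axis=1)
  return ranks[np.arange(len(labels)), labels].astype(np.float32) + 1.0


# Toy logits over a 3-token vocabulary and the corresponding ground-truth labels.
logits_train = np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0.0],
                         [0.4, 0.5, 0.1], [0.9, 0.0, 0.1]])
labels_train = np.array([2, 0, 1, 0])
logits_test = np.array([[0.25, 0.4, 0.35], [0.3, 0.3, 0.4],
                        [0.3, 0.35, 0.35], [0.2, 0.4, 0.4]])
labels_test = np.array([2, 0, 0, 1])

# A smaller rank means the model is more confident about the true token, which
# is the behaviour expected on training members, so the ranks are passed in as
# the "loss" signal for a threshold attack.
results = mia.run_attacks(
    AttackInputData(
        loss_train=rank_of_true_label(logits_train, labels_train),
        loss_test=rank_of_true_label(logits_test, labels_test)),
    attack_types=(AttackType.THRESHOLD_ATTACK,))
print(results.single_attack_results[0].roc_curve.get_auc())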