diff --git a/tensorflow_privacy/privacy/membership_inference_attack/codelabs/third_party/seq2seq_membership_inference/seq2seq_membership_inference_codelab.ipynb b/tensorflow_privacy/privacy/membership_inference_attack/codelabs/third_party/seq2seq_membership_inference/seq2seq_membership_inference_codelab.ipynb index 4045e77..4c80bd3 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/codelabs/third_party/seq2seq_membership_inference/seq2seq_membership_inference_codelab.ipynb +++ b/tensorflow_privacy/privacy/membership_inference_attack/codelabs/third_party/seq2seq_membership_inference/seq2seq_membership_inference_codelab.ipynb @@ -1163,7 +1163,10 @@ "fig = plotting.plot_roc_curve(attack_result.get_result_with_max_auc().roc_curve)\n", "\n", "# Print a user-friendly summary of the attacks\n", - "print(attack_result.summary())" + "print(attack_result.summary())\n", + "\n", + "# Print metadata of the target model\n", + "print(attack_result.privacy_report_metadata)" ] } ], diff --git a/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia.py b/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia.py index 6225d3f..0cc1daa 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia.py @@ -25,9 +25,11 @@ from dataclasses import dataclass from scipy.stats import rankdata from sklearn import metrics, model_selection +from tensorflow.keras.metrics import SparseCategoricalCrossentropy, SparseCategoricalAccuracy + from tensorflow_privacy.privacy.membership_inference_attack import models -from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata, AttackResults, \ - RocCurve, SingleAttackResult, SingleSliceSpec, AttackType +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata, \ + AttackResults, RocCurve, SingleAttackResult, SingleSliceSpec, AttackType from tensorflow_privacy.privacy.membership_inference_attack.models import AttackerData, _sample_multidimensional_array @@ -116,9 +118,10 @@ class Seq2SeqAttackInputData: return '\n'.join(result) -def _get_average_ranks(logits: Iterator[np.ndarray], - labels: Iterator[np.ndarray]) -> np.ndarray: - """Returns the average rank of tokens in a batch of sequences. +def _get_attack_features_and_metadata(logits: Iterator[np.ndarray], + labels: Iterator[np.ndarray]) -> (np.ndarray, float): + """Returns the average rank of tokens per batch of sequences, + and the loss computed using logits and labels. Args: logits: Logits returned by a seq2seq model, dim = (num_batches, @@ -127,18 +130,48 @@ def _get_average_ranks(logits: Iterator[np.ndarray], num_sequences, num_tokens, 1). Returns: - An array of average ranks, dim = (num_batches, 1). + 1. An array of average ranks, dim = (num_batches, 1). Each average rank is calculated over ranks of tokens in sequences of a particular batch. + 2. Loss computed over all logits and labels. + 3. Accuracy computed over all logits and labels. """ ranks = [] + loss = SparseCategoricalCrossentropy(from_logits=True) + accuracy = SparseCategoricalAccuracy() for batch_logits, batch_labels in zip(logits, labels): - batch_ranks = [] - for sequence_logits, sequence_labels in zip(batch_logits, batch_labels): - batch_ranks += _get_ranks_for_sequence(sequence_logits, sequence_labels) + # Compute average rank for the current batch. + batch_ranks = _get_batch_ranks(batch_logits, batch_labels) ranks.append(np.mean(batch_ranks)) - return np.array(ranks) + # Update overall loss with loss of the current batch. + _update_batch_loss(batch_logits, batch_labels, loss) + + # Update overall accuracy with accuracy of the current batch. + _update_batch_accuracy(batch_logits, batch_labels, accuracy) + + return np.array(ranks), loss.result().numpy(), accuracy.result().numpy() + + +def _get_batch_ranks(batch_logits: np.ndarray, + batch_labels: np.ndarray) -> np.ndarray: + """Returns the ranks of tokens in a batch of sequences. + + Args: + batch_logits: Logits returned by a seq2seq model, dim = (num_sequences, + num_tokens, vocab_size). + batch_labels: Target labels for the seq2seq model, dim = (num_sequences, + num_tokens, 1). + + Returns: + An array of ranks of tokens in a batch of sequences, dim = (num_sequences, + num_tokens, 1) + """ + batch_ranks = [] + for sequence_logits, sequence_labels in zip(batch_logits, batch_labels): + batch_ranks += _get_ranks_for_sequence(sequence_logits, sequence_labels) + + return np.array(batch_ranks) def _get_ranks_for_sequence(logits: np.ndarray, @@ -160,28 +193,67 @@ def _get_ranks_for_sequence(logits: np.ndarray, return sequence_ranks +def _update_batch_loss(batch_logits: np.ndarray, + batch_labels: np.ndarray, + loss: SparseCategoricalCrossentropy): + """Updates the loss metric per batch. + + Args: + batch_logits: Logits returned by a seq2seq model, dim = (num_sequences, + num_tokens, vocab_size). + batch_labels: Target labels for the seq2seq model, dim = (num_sequences, + num_tokens, 1). + loss: SparseCategoricalCrossentropy loss metric. + """ + for sequence_logits, sequence_labels in zip(batch_logits, batch_labels): + loss.update_state(sequence_labels.astype(np.float32), + sequence_logits.astype(np.float32)) + + +def _update_batch_accuracy(batch_logits: np.ndarray, + batch_labels: np.ndarray, + accuracy: SparseCategoricalAccuracy): + """Updates the accuracy metric per batch. + + Args: + batch_logits: Logits returned by a seq2seq model, dim = (num_sequences, + num_tokens, vocab_size). + batch_labels: Target labels for the seq2seq model, dim = (num_sequences, + num_tokens, 1). + accuracy: SparseCategoricalAccuracy accuracy metric. + """ + for sequence_logits, sequence_labels in zip(batch_logits, batch_labels): + accuracy.update_state(sequence_labels.astype(np.float32), + sequence_logits.astype(np.float32)) + + def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData, test_fraction: float = 0.25, - balance: bool = True) -> AttackerData: + balance: bool = True, + privacy_report_metadata: PrivacyReportMetadata = None) -> AttackerData: """Prepare Seq2SeqAttackInputData to train ML attackers. Uses logits and losses to generate ranks and performs a random train-test split. + Also computes metadata (loss, accuracy) for the model under attack + and populates respective fields of PrivacyReportMetadata. + Args: attack_input_data: Original Seq2SeqAttackInputData test_fraction: Fraction of the dataset to include in the test split. balance: Whether the training and test sets for the membership inference attacker should have a balanced (roughly equal) number of samples from the training and test sets used to develop the model under attack. + privacy_report_metadata: the metadata of the model under attack. Returns: AttackerData. """ - attack_input_train = _get_average_ranks(attack_input_data.logits_train, - attack_input_data.labels_train) - attack_input_test = _get_average_ranks(attack_input_data.logits_test, - attack_input_data.labels_test) + attack_input_train, loss_train, accuracy_train = _get_attack_features_and_metadata(attack_input_data.logits_train, + attack_input_data.labels_train) + attack_input_test, loss_test, accuracy_test = _get_attack_features_and_metadata(attack_input_data.logits_test, + attack_input_data.labels_test) if balance: min_size = min(len(attack_input_train), len(attack_input_test)) @@ -204,18 +276,24 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData, model_selection.train_test_split( features_all, labels_all, test_size=test_fraction, stratify=labels_all) + # Populate fields of privacy report metadata + privacy_report_metadata.loss_train = loss_train + privacy_report_metadata.loss_test = loss_test + privacy_report_metadata.accuracy_train = accuracy_train + privacy_report_metadata.accuracy_test = accuracy_test + return AttackerData(features_train, is_training_labels_train, features_test, is_training_labels_test) def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData, - unused_report_metadata: PrivacyReportMetadata = None, + privacy_report_metadata: PrivacyReportMetadata = None, balance_attacker_training: bool = True) -> AttackResults: """Runs membership inference attacks on a seq2seq model. Args: attack_input: input data for running an attack - unused_report_metadata: the metadata of the model under attack. + privacy_report_metadata: the metadata of the model under attack. balance_attacker_training: Whether the training and test sets for the membership inference attacker should have a balanced (roughly equal) number of samples from the training and test sets used to develop the @@ -231,8 +309,12 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData, # as it makes the most sense for single-number features. attacker = models.LogisticRegressionAttacker() - prepared_attacker_data = create_seq2seq_attacker_data( - attack_input, balance=balance_attacker_training) + # Create attacker data and populate fields of privacy_report_metadata + if privacy_report_metadata is None: + privacy_report_metadata = PrivacyReportMetadata() + prepared_attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input, + balance=balance_attacker_training, + privacy_report_metadata=privacy_report_metadata) attacker.train_model(prepared_attacker_data.features_train, prepared_attacker_data.is_training_labels_train) @@ -253,5 +335,6 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData, roc_curve=roc_curve) ] - return AttackResults(single_attack_results=attack_results) - + return AttackResults( + single_attack_results=attack_results, + privacy_report_metadata=privacy_report_metadata) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia_test.py b/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia_test.py index 2da62c9..295b457 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia_test.py @@ -17,7 +17,7 @@ from absl.testing import absltest import numpy as np -from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType, PrivacyReportMetadata from tensorflow_privacy.privacy.membership_inference_attack.seq2seq_mia import Seq2SeqAttackInputData, \ create_seq2seq_attacker_data, run_seq2seq_attack @@ -145,13 +145,23 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase): vocab_size=3, train_size=3, test_size=2) - attacker_data = create_seq2seq_attacker_data(attack_input, 0.25, balance=False) + privacy_report_metadata = PrivacyReportMetadata() + attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input, + test_fraction=0.25, + balance=False, + privacy_report_metadata=privacy_report_metadata) self.assertLen(attacker_data.features_train, 3) self.assertLen(attacker_data.features_test, 2) for _, feature in enumerate(attacker_data.features_train): self.assertLen(feature, 1) # each feature has one average rank + # Tests that fields of PrivacyReportMetadata are populated. + self.assertIsNotNone(privacy_report_metadata.loss_train) + self.assertIsNotNone(privacy_report_metadata.loss_test) + self.assertIsNotNone(privacy_report_metadata.accuracy_train) + self.assertIsNotNone(privacy_report_metadata.accuracy_test) + def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self): attack_input = Seq2SeqAttackInputData( @@ -210,13 +220,23 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase): vocab_size=3, train_size=3, test_size=3) - attacker_data = create_seq2seq_attacker_data(attack_input, 0.33, balance=True) + privacy_report_metadata = PrivacyReportMetadata() + attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input, + test_fraction=0.33, + balance=True, + privacy_report_metadata=privacy_report_metadata) self.assertLen(attacker_data.features_train, 4) self.assertLen(attacker_data.features_test, 2) for _, feature in enumerate(attacker_data.features_train): self.assertLen(feature, 1) # each feature has one average rank + # Tests that fields of PrivacyReportMetadata are populated. + self.assertIsNotNone(privacy_report_metadata.loss_train) + self.assertIsNotNone(privacy_report_metadata.loss_test) + self.assertIsNotNone(privacy_report_metadata.accuracy_train) + self.assertIsNotNone(privacy_report_metadata.accuracy_test) + def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence, vocab_size): @@ -315,6 +335,21 @@ class RunSeq2SeqAttackTest(absltest.TestCase): np.testing.assert_almost_equal( seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2) + def test_run_seq2seq_attack_calculates_correct_metadata(self): + result = run_seq2seq_attack(get_seq2seq_test_input( + n_train=20, + n_test=10, + max_seq_in_batch=3, + max_tokens_in_sequence=5, + vocab_size=3, + seed=12345), + balance_attacker_training=False) + metadata = result.privacy_report_metadata + np.testing.assert_almost_equal(metadata.loss_train, 1.11, decimal=2) + np.testing.assert_almost_equal(metadata.loss_test, 1.10, decimal=2) + np.testing.assert_almost_equal(metadata.accuracy_train, 0.40, decimal=2) + np.testing.assert_almost_equal(metadata.accuracy_test, 0.34, decimal=2) + if __name__ == '__main__': absltest.main()