diff --git a/tensorflow_privacy/privacy/membership_inference_attack/models.py b/tensorflow_privacy/privacy/membership_inference_attack/models.py
index 54674e0..70becaf 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/models.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/models.py
@@ -17,6 +17,7 @@
 from dataclasses import dataclass
 
 import numpy as np
+from scipy.stats import rankdata
 from sklearn import ensemble
 from sklearn import linear_model
 from sklearn import model_selection
@@ -24,6 +25,7 @@
 from sklearn import neighbors
 from sklearn import neural_network
 
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
+from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
 
 
 @dataclass
@@ -110,6 +112,93 @@ def _column_stack(logits, loss):
   return np.column_stack((logits, loss))
 
 
+def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
+                                 test_fraction: float = 0.25,
+                                 balance: bool = True) -> AttackerData:
+  """Prepares Seq2SeqAttackInputData to train ML attackers.
+
+  Uses logits and target labels to compute average token ranks, then performs
+  a random train-test split over the resulting features.
+
+  Args:
+    attack_input_data: Original Seq2SeqAttackInputData.
+    test_fraction: Fraction of the dataset to include in the test split.
+    balance: Whether the training and test sets for the membership inference
+      attacker should have a balanced (roughly equal) number of samples from
+      the training and test sets used to develop the model under attack.
+
+  Returns:
+    AttackerData.
+  """
+  attack_input_train = _get_average_ranks(attack_input_data.logits_train,
+                                          attack_input_data.labels_train)
+  attack_input_test = _get_average_ranks(attack_input_data.logits_test,
+                                         attack_input_data.labels_test)
+
+  if balance:
+    min_size = min(attack_input_data.train_size, attack_input_data.test_size)
+    attack_input_train = _sample_multidimensional_array(attack_input_train,
+                                                        min_size)
+    attack_input_test = _sample_multidimensional_array(attack_input_test,
+                                                       min_size)
+
+  features_all = np.concatenate((attack_input_train, attack_input_test))
+
+  # Reshape for classifying one-dimensional features.
+  features_all = features_all.reshape(-1, 1)
+
+  # Members of the model's training set get label 0, non-members label 1.
+  labels_all = np.concatenate((np.zeros(len(attack_input_train)),
+                               np.ones(len(attack_input_test))))
+
+  # Perform a stratified train-test split.
+  features_train, features_test, \
+  is_training_labels_train, is_training_labels_test = \
+      model_selection.train_test_split(
+          features_all, labels_all, test_size=test_fraction,
+          stratify=labels_all)
+
+  return AttackerData(features_train, is_training_labels_train, features_test,
+                      is_training_labels_test)
+
+
+def _get_average_ranks(logits, labels):
+  """Returns the average rank of tokens in a batch of sequences.
+
+  Args:
+    logits: Logits returned by a seq2seq model,
+      dim = (num_batches, num_sequences, num_tokens, vocab_size).
+    labels: Target labels for the seq2seq model,
+      dim = (num_batches, num_sequences, num_tokens, 1).
+
+  Returns:
+    An array of average ranks, dim = (num_batches,). Each average rank is
+    computed over the ranks of all target tokens in one batch of sequences.
+ """ + ranks = [] + for batch_logits, batch_labels in zip(logits, labels): + batch_ranks = [] + for sequence_logits, sequence_labels in zip(batch_logits, batch_labels): + batch_ranks += _get_ranks_for_sequence(sequence_logits, sequence_labels) + ranks.append(np.mean(batch_ranks)) + + return np.array(ranks) + + +def _get_ranks_for_sequence(logits, labels): + """Returns ranks for a sequence. + + Args: + logits: Logits of a single sequence, dim = (num_tokens, vocab_size). + labels: Target labels of a single sequence, dim = (num_tokens, 1). + + Returns: + An array of ranks for tokens in the sequence, dim = (num_tokens, 1). + """ + scores = -logits + all_ranks = np.empty_like(scores) + for i, s in enumerate(scores): + all_ranks[i] = rankdata(s, method='min') - 1 + sequence_ranks = all_ranks[np.arange(len(all_ranks)), labels.astype(int)].tolist() + + return sequence_ranks + + class TrainedAttacker: """Base class for training attack models.""" model = None