Add rank generation code

This commit is contained in:
amad-person 2020-11-06 16:43:46 +08:00
parent 9f07f2a871
commit cd57910e5c

View file

@ -17,6 +17,7 @@
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
from scipy.stats import rankdata
from sklearn import ensemble from sklearn import ensemble
from sklearn import linear_model from sklearn import linear_model
from sklearn import model_selection from sklearn import model_selection
@ -24,6 +25,7 @@ from sklearn import neighbors
from sklearn import neural_network from sklearn import neural_network
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
@dataclass @dataclass
@ -110,6 +112,93 @@ def _column_stack(logits, loss):
return np.column_stack((logits, loss)) return np.column_stack((logits, loss))
def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
                                 test_fraction: float = 0.25,
                                 balance: bool = True) -> AttackerData:
  """Builds attacker training/test data from seq2seq model outputs.

  Average token ranks are computed from the logits and target labels of the
  attacked model's train and test data. Each average rank becomes a single
  one-dimensional feature, labelled 0 when it comes from the `*_train` fields
  and 1 when it comes from the `*_test` fields. The combined data is then
  divided with a stratified `train_test_split`.

  Args:
    attack_input_data: Seq2SeqAttackInputData providing `logits_train`,
      `labels_train`, `logits_test`, `labels_test`, `train_size` and
      `test_size`.
    test_fraction: Fraction of the combined data to place in the attacker's
      test split.
    balance: If True, both groups of average ranks are passed through
      `_sample_multidimensional_array` with the smaller of `train_size` and
      `test_size`, so the attacker sees a roughly equal number of samples
      from each group.

  Returns:
    AttackerData constructed from (features_train, labels_train,
    features_test, labels_test), in that positional order.
  """
  member_ranks = _get_average_ranks(attack_input_data.logits_train,
                                    attack_input_data.labels_train)
  non_member_ranks = _get_average_ranks(attack_input_data.logits_test,
                                        attack_input_data.labels_test)

  if balance:
    # NOTE(review): assumes train_size/test_size count the same units as the
    # entries of the average-rank arrays -- confirm against the data class.
    sample_count = min(attack_input_data.train_size,
                       attack_input_data.test_size)
    member_ranks = _sample_multidimensional_array(member_ranks, sample_count)
    non_member_ranks = _sample_multidimensional_array(non_member_ranks,
                                                      sample_count)

  # The classifiers expect 2-D input, so the single rank feature is turned
  # into a column vector.
  features_all = np.concatenate((member_ranks, non_member_ranks)).reshape(-1, 1)
  labels_all = np.concatenate(
      (np.zeros(len(member_ranks)), np.ones(len(non_member_ranks))))

  # Stratify so both splits keep the same proportion of each label.
  split = model_selection.train_test_split(
      features_all, labels_all, test_size=test_fraction, stratify=labels_all)
  features_train, features_test, labels_train, labels_test = split

  return AttackerData(features_train, labels_train, features_test, labels_test)
def _get_average_ranks(logits, labels):
  """Computes one mean token rank per batch of sequences.

  Args:
    logits: Logits returned by a seq2seq model,
      dim = (num_batches, num_sequences, num_tokens, vocab_size).
    labels: Target labels for the seq2seq model,
      dim = (num_batches, num_sequences, num_tokens, 1).

  Returns:
    A one-dimensional array with one entry per batch. Each entry is the mean
    of the ranks of all tokens in all sequences of that batch.
  """
  per_batch_means = []
  for batch_logits, batch_labels in zip(logits, labels):
    # Pool the token ranks of every sequence in the batch before averaging,
    # so each token (not each sequence) carries equal weight.
    token_ranks = [
        rank
        for seq_logits, seq_labels in zip(batch_logits, batch_labels)
        for rank in _get_ranks_for_sequence(seq_logits, seq_labels)
    ]
    per_batch_means.append(np.mean(token_ranks))
  return np.array(per_batch_means)
def _get_ranks_for_sequence(logits, labels):
  """Returns the rank of each target token within its row of logits.

  A rank of 0 means the target token has the highest logit at that position.
  Tied logits share the lowest rank of their group (`method='min'`).

  Args:
    logits: Logits of a single sequence, dim = (num_tokens, vocab_size).
    labels: Target labels of a single sequence, dim = (num_tokens,) or
      (num_tokens, 1). Values are cast to int before being used as indices.

  Returns:
    A flat list of num_tokens ranks, one per token in the sequence.
  """
  # Negate so that the largest logit receives the smallest (best) rank.
  scores = -np.asarray(logits)
  # Flatten the labels: a (num_tokens, 1) column would otherwise broadcast
  # against the row indices below and select a (num_tokens, num_tokens) grid
  # instead of one rank per token.
  token_ids = np.asarray(labels).astype(int).reshape(-1)

  all_ranks = np.empty_like(scores)
  for i, token_scores in enumerate(scores):
    # rankdata is 1-based; shift so the best possible rank is 0.
    all_ranks[i] = rankdata(token_scores, method='min') - 1

  # Pick, for every token position, the rank of its target label.
  return all_ranks[np.arange(len(all_ranks)), token_ids].tolist()
class TrainedAttacker: class TrainedAttacker:
"""Base class for training attack models.""" """Base class for training attack models."""
model = None model = None