Fix nits
This commit is contained in:
parent
ed2bdcadfa
commit
afe3944b1d
3 changed files with 23 additions and 23 deletions
|
@ -374,10 +374,10 @@ def _append_array_shape(arr: np.array, arr_name: str, result):
|
|||
result.append(' %s with shape: %s,' % (arr_name, arr.shape))
|
||||
|
||||
|
||||
def _is_generator(gen, gen_name):
|
||||
"""Checks whether gen is a generator."""
|
||||
if gen is not None and not isinstance(gen, Iterator):
|
||||
raise ValueError('%s should be a generator.' % gen_name)
|
||||
def _is_iterator(obj, obj_name):
|
||||
"""Checks whether obj is a generator."""
|
||||
if obj is not None and not isinstance(obj, Iterator):
|
||||
raise ValueError('%s should be a generator.' % obj_name)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -393,7 +393,7 @@ class Seq2SeqAttackInputData:
|
|||
labels_train: Iterator[np.ndarray] = None
|
||||
labels_test: Iterator[np.ndarray] = None
|
||||
|
||||
# Denotes size of the target sequence vocabulary.
|
||||
# Size of the target sequence vocabulary.
|
||||
vocab_size: int = None
|
||||
|
||||
# Train, test size = number of batches in training, test set.
|
||||
|
@ -431,10 +431,10 @@ class Seq2SeqAttackInputData:
|
|||
if self.test_size is not None and not int:
|
||||
raise ValueError('test_size should be of integer type')
|
||||
|
||||
_is_generator(self.logits_train, 'logits_train')
|
||||
_is_generator(self.logits_test, 'logits_test')
|
||||
_is_generator(self.labels_train, 'labels_train')
|
||||
_is_generator(self.labels_test, 'labels_test')
|
||||
_is_iterator(self.logits_train, 'logits_train')
|
||||
_is_iterator(self.logits_test, 'logits_test')
|
||||
_is_iterator(self.labels_train, 'labels_train')
|
||||
_is_iterator(self.labels_test, 'labels_test')
|
||||
|
||||
def __str__(self):
|
||||
"""Return the shapes of variables that are not None."""
|
||||
|
|
|
@ -36,20 +36,18 @@ def get_test_input(n_train, n_test):
|
|||
|
||||
|
||||
def get_seq2seq_test_input(n_train, n_test, max_seq_in_batch, max_tokens_in_sequence, vocab_size, seed=None):
|
||||
"""Get example inputs for attacks on seq2seq models."""
|
||||
"""Returns example inputs for attacks on seq2seq models."""
|
||||
if seed is not None:
|
||||
np.random.seed(seed=seed)
|
||||
|
||||
logits_train = []
|
||||
labels_train = []
|
||||
logits_train, labels_train = [], []
|
||||
for i in range(n_train):
|
||||
num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1
|
||||
batch_logits, batch_labels = _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence, vocab_size)
|
||||
logits_train.append(batch_logits)
|
||||
labels_train.append(batch_labels)
|
||||
|
||||
logits_test = []
|
||||
labels_test = []
|
||||
logits_test, labels_test = [], []
|
||||
for i in range(n_test):
|
||||
num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1
|
||||
batch_logits, batch_labels = _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence, vocab_size)
|
||||
|
|
|
@ -23,6 +23,7 @@ from sklearn import linear_model
|
|||
from sklearn import model_selection
|
||||
from sklearn import neighbors
|
||||
from sklearn import neural_network
|
||||
from typing import Iterator, List
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
|
||||
|
@ -134,8 +135,8 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
|
|||
attack_input_test = _get_average_ranks(attack_input_data.logits_test, attack_input_data.labels_test)
|
||||
|
||||
if balance:
|
||||
min_size = min(attack_input_data.train_size,
|
||||
attack_input_data.test_size)
|
||||
min_size = min(len(attack_input_train),
|
||||
len(attack_input_test))
|
||||
attack_input_train = _sample_multidimensional_array(attack_input_train,
|
||||
min_size)
|
||||
attack_input_test = _sample_multidimensional_array(attack_input_test,
|
||||
|
@ -159,7 +160,8 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
|
|||
is_training_labels_test)
|
||||
|
||||
|
||||
def _get_average_ranks(logits, labels):
|
||||
def _get_average_ranks(logits: Iterator[np.ndarray],
|
||||
labels: Iterator[np.ndarray]) -> np.ndarray:
|
||||
"""Returns the average rank of tokens in a batch of sequences.
|
||||
|
||||
Args:
|
||||
|
@ -180,7 +182,8 @@ def _get_average_ranks(logits, labels):
|
|||
return np.array(ranks)
|
||||
|
||||
|
||||
def _get_ranks_for_sequence(logits, labels):
|
||||
def _get_ranks_for_sequence(logits: np.ndarray,
|
||||
labels: np.ndarray) -> List:
|
||||
"""Returns ranks for a sequence.
|
||||
|
||||
Args:
|
||||
|
@ -190,11 +193,10 @@ def _get_ranks_for_sequence(logits, labels):
|
|||
Returns:
|
||||
An array of ranks for tokens in the sequence, dim = (num_tokens, 1).
|
||||
"""
|
||||
scores = -logits
|
||||
all_ranks = np.empty_like(scores)
|
||||
for i, s in enumerate(scores):
|
||||
all_ranks[i] = rankdata(s, method='min') - 1
|
||||
sequence_ranks = all_ranks[np.arange(len(all_ranks)), labels.astype(int)].tolist()
|
||||
sequence_ranks = []
|
||||
for logit, label in zip(logits, labels.astype(int)):
|
||||
rank = rankdata(-logit, method='min')[label] - 1.0
|
||||
sequence_ranks.append(rank)
|
||||
|
||||
return sequence_ranks
|
||||
|
||||
|
|
Loading…
Reference in a new issue