diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index fd549d2..c297341 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -374,10 +374,10 @@ def _append_array_shape(arr: np.array, arr_name: str, result): result.append(' %s with shape: %s,' % (arr_name, arr.shape)) -def _is_generator(gen, gen_name): - """Checks whether gen is a generator.""" - if gen is not None and not isinstance(gen, Iterator): - raise ValueError('%s should be a generator.' % gen_name) +def _is_iterator(obj, obj_name): + """Raises ValueError if obj is neither None nor an iterator.""" + if obj is not None and not isinstance(obj, Iterator): + raise ValueError('%s should be a generator.' % obj_name) @dataclass @@ -393,7 +393,7 @@ class Seq2SeqAttackInputData: labels_train: Iterator[np.ndarray] = None labels_test: Iterator[np.ndarray] = None - # Denotes size of the target sequence vocabulary. + # Size of the target sequence vocabulary. vocab_size: int = None # Train, test size = number of batches in training, test set. 
@@ -431,10 +431,10 @@ class Seq2SeqAttackInputData: if self.test_size is not None and not int: raise ValueError('test_size should be of integer type') - _is_generator(self.logits_train, 'logits_train') - _is_generator(self.logits_test, 'logits_test') - _is_generator(self.labels_train, 'labels_train') - _is_generator(self.labels_test, 'labels_test') + _is_iterator(self.logits_train, 'logits_train') + _is_iterator(self.logits_test, 'logits_test') + _is_iterator(self.labels_train, 'labels_train') + _is_iterator(self.labels_test, 'labels_test') def __str__(self): """Return the shapes of variables that are not None.""" diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py index 5609eab..6182df9 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py @@ -36,20 +36,18 @@ def get_test_input(n_train, n_test): def get_seq2seq_test_input(n_train, n_test, max_seq_in_batch, max_tokens_in_sequence, vocab_size, seed=None): - """Get example inputs for attacks on seq2seq models.""" + """Returns example inputs for attacks on seq2seq models.""" if seed is not None: np.random.seed(seed=seed) - logits_train = [] - labels_train = [] + logits_train, labels_train = [], [] for i in range(n_train): num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1 batch_logits, batch_labels = _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence, vocab_size) logits_train.append(batch_logits) labels_train.append(batch_labels) - logits_test = [] - labels_test = [] + logits_test, labels_test = [], [] for i in range(n_test): num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1 batch_logits, batch_labels = _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence, vocab_size) diff --git 
a/tensorflow_privacy/privacy/membership_inference_attack/models.py b/tensorflow_privacy/privacy/membership_inference_attack/models.py index 70becaf..00f9eb6 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/models.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/models.py @@ -23,6 +23,7 @@ from sklearn import linear_model from sklearn import model_selection from sklearn import neighbors from sklearn import neural_network +from typing import Iterator, List from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData @@ -134,8 +135,8 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData, attack_input_test = _get_average_ranks(attack_input_data.logits_test, attack_input_data.labels_test) if balance: - min_size = min(attack_input_data.train_size, - attack_input_data.test_size) + min_size = min(len(attack_input_train), + len(attack_input_test)) attack_input_train = _sample_multidimensional_array(attack_input_train, min_size) attack_input_test = _sample_multidimensional_array(attack_input_test, @@ -159,7 +160,8 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData, is_training_labels_test) -def _get_average_ranks(logits, labels): +def _get_average_ranks(logits: Iterator[np.ndarray], + labels: Iterator[np.ndarray]) -> np.ndarray: """Returns the average rank of tokens in a batch of sequences. Args: @@ -180,7 +182,8 @@ def _get_average_ranks(logits, labels): return np.array(ranks) -def _get_ranks_for_sequence(logits, labels): +def _get_ranks_for_sequence(logits: np.ndarray, + labels: np.ndarray) -> List: """Returns ranks for a sequence. Args: @@ -190,11 +193,10 @@ def _get_ranks_for_sequence(logits, labels): Returns: An array of ranks for tokens in the sequence, dim = (num_tokens, 1). 
""" - scores = -logits - all_ranks = np.empty_like(scores) - for i, s in enumerate(scores): - all_ranks[i] = rankdata(s, method='min') - 1 - sequence_ranks = all_ranks[np.arange(len(all_ranks)), labels.astype(int)].tolist() + sequence_ranks = [] + for logit, label in zip(logits, labels.astype(int)): + rank = rankdata(-logit, method='min')[label] - 1.0 + sequence_ranks.append(rank) return sequence_ranks