# Review notes (leading text before the first "diff --git" header; patch
# tools skip it).
#
# Formatting: this patch had been collapsed onto a handful of physical lines.
# The line structure below was rebuilt so that every hunk matches the counts
# in its "@@ -a,b +c,d @@" header. Indentation and the exact position of
# blank context lines are reconstructed, not recovered -- confirm against the
# target files (or apply with whitespace tolerance) before relying on it.
#
# Findings in the added code (left unchanged so that the expectations in the
# test hunk stay valid):
#  1. _get_batch_loss_metrics returns `batch_loss / batch_length`, and
#     _get_attack_features_and_metadata adds that value to `loss` and then
#     divides by `dataset_length` (the sum of the batch lengths). The division
#     by the number of sequences is therefore applied twice; returning the
#     undivided batch sum would make `loss / dataset_length` a per-sequence
#     mean. The expected values 2.08 / 2.02 in the test hunk would have to be
#     regenerated after such a fix.
#  2. create_seq2seq_attacker_data moves `privacy_report_metadata` to the
#     second position and drops its `= None` default, while the body still
#     keeps the `privacy_report_metadata or PrivacyReportMetadata()` fallback.
#     Callers passing `test_fraction` positionally now bind it to the wrong
#     parameter; keeping the `= None` default would match the fallback.
#  3. `loss / dataset_length`, `correct_preds / total_preds` and
#     `batch_loss / batch_length` raise ZeroDivisionError when the iterators
#     yield no batches or a batch is empty.
#  4. Return annotations such as `(np.ndarray, float, float)` are tuple
#     expressions, not typing constructs.
#
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia.py b/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia.py
index 0cc1daa..87552b9 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia.py
@@ -25,7 +25,9 @@ from dataclasses import dataclass
 
 from scipy.stats import rankdata
 from sklearn import metrics, model_selection
-from tensorflow.keras.metrics import SparseCategoricalCrossentropy, SparseCategoricalAccuracy
+import tensorflow.keras.backend as K
+from tensorflow.keras.losses import sparse_categorical_crossentropy
+from tensorflow.keras.metrics import sparse_categorical_accuracy
 
 from tensorflow_privacy.privacy.membership_inference_attack import models
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata, \
@@ -95,7 +97,7 @@ class Seq2SeqAttackInputData:
     _is_iterator(self.labels_test, 'labels_test')
 
   def __str__(self):
-    """Return the shapes of variables that are not None."""
+    """Returns the shapes of variables that are not None."""
     result = ['AttackInputData(']
 
     if self.vocab_size is not None and self.train_size is not None:
@@ -119,7 +121,7 @@ class Seq2SeqAttackInputData:
 
 
 def _get_attack_features_and_metadata(logits: Iterator[np.ndarray],
-                                      labels: Iterator[np.ndarray]) -> (np.ndarray, float):
+                                      labels: Iterator[np.ndarray]) -> (np.ndarray, float, float):
   """Returns the average rank of tokens per batch of sequences,
   and the loss computed using logits and labels.
 
@@ -137,20 +139,30 @@ def _get_attack_features_and_metadata(logits: Iterator[np.ndarray],
     3. Accuracy computed over all logits and labels.
   """
   ranks = []
-  loss = SparseCategoricalCrossentropy(from_logits=True)
-  accuracy = SparseCategoricalAccuracy()
+  loss = 0.0
+  dataset_length = 0.0
+  correct_preds = 0
+  total_preds = 0
   for batch_logits, batch_labels in zip(logits, labels):
     # Compute average rank for the current batch.
     batch_ranks = _get_batch_ranks(batch_logits, batch_labels)
     ranks.append(np.mean(batch_ranks))
 
-    # Update overall loss with loss of the current batch.
-    _update_batch_loss(batch_logits, batch_labels, loss)
+    # Update overall loss metrics with metrics of the current batch.
+    batch_loss, batch_length = _get_batch_loss_metrics(batch_logits, batch_labels)
+    loss += batch_loss
+    dataset_length += batch_length
 
-    # Update overall accuracy with accuracy of the current batch.
-    _update_batch_accuracy(batch_logits, batch_labels, accuracy)
+    # Update overall accuracy metrics with metrics of the current batch.
+    batch_correct_preds, batch_total_preds = _get_batch_accuracy_metrics(batch_logits, batch_labels)
+    correct_preds += batch_correct_preds
+    total_preds += batch_total_preds
 
-  return np.array(ranks), loss.result().numpy(), accuracy.result().numpy()
+  # Compute loss and accuracy for the dataset.
+  loss = loss / dataset_length
+  accuracy = correct_preds / total_preds
+
+  return np.array(ranks), loss, accuracy
 
 
 def _get_batch_ranks(batch_logits: np.ndarray,
@@ -193,45 +205,53 @@ def _get_ranks_for_sequence(logits: np.ndarray,
   return sequence_ranks
 
 
-def _update_batch_loss(batch_logits: np.ndarray,
-                       batch_labels: np.ndarray,
-                       loss: SparseCategoricalCrossentropy):
-  """Updates the loss metric per batch.
+def _get_batch_loss_metrics(batch_logits: np.ndarray,
+                            batch_labels: np.ndarray) -> (float, int):
+  """Returns the loss, number of sequences for a batch.
 
   Args:
     batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
       num_tokens, vocab_size).
     batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
       num_tokens, 1).
-    loss: SparseCategoricalCrossentropy loss metric.
   """
+  batch_loss = 0.0
+  batch_length = len(batch_logits)
   for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
-    loss.update_state(sequence_labels.astype(np.float32),
-                      sequence_logits.astype(np.float32))
+    sequence_loss = sparse_categorical_crossentropy(K.constant(sequence_labels),
+                                                    K.constant(sequence_logits),
+                                                    from_logits=True)
+    batch_loss += sequence_loss.numpy().sum()
+
+  return batch_loss / batch_length, batch_length
 
 
-def _update_batch_accuracy(batch_logits: np.ndarray,
-                           batch_labels: np.ndarray,
-                           accuracy: SparseCategoricalAccuracy):
-  """Updates the accuracy metric per batch.
+def _get_batch_accuracy_metrics(batch_logits: np.ndarray,
+                                batch_labels: np.ndarray) -> (float, float):
+  """Returns the number of correct predictions, total number of predictions for a batch.
 
   Args:
     batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
       num_tokens, vocab_size).
     batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
       num_tokens, 1).
-    accuracy: SparseCategoricalAccuracy accuracy metric.
   """
+  batch_correct_preds = 0.0
+  batch_total_preds = 0.0
   for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
-    accuracy.update_state(sequence_labels.astype(np.float32),
-                          sequence_logits.astype(np.float32))
+    preds = sparse_categorical_accuracy(K.constant(sequence_labels),
+                                        K.constant(sequence_logits))
+    batch_correct_preds += preds.numpy().sum()
+    batch_total_preds += len(sequence_labels)
+
+  return batch_correct_preds, batch_total_preds
 
 
 def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
+                                 privacy_report_metadata: PrivacyReportMetadata,
                                  test_fraction: float = 0.25,
-                                 balance: bool = True,
-                                 privacy_report_metadata: PrivacyReportMetadata = None) -> AttackerData:
-  """Prepare Seq2SeqAttackInputData to train ML attackers.
+                                 balance: bool = True) -> AttackerData:
+  """Prepares Seq2SeqAttackInputData to train ML attackers.
 
   Uses logits and losses to generate ranks and performs a random train-test
   split.
@@ -241,11 +261,11 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
 
   Args:
     attack_input_data: Original Seq2SeqAttackInputData
+    privacy_report_metadata: the metadata of the model under attack.
     test_fraction: Fraction of the dataset to include in the test split.
     balance: Whether the training and test sets for the membership inference
       attacker should have a balanced (roughly equal) number of samples from the
       training and test sets used to develop the model under attack.
-    privacy_report_metadata: the metadata of the model under attack.
 
   Returns:
     AttackerData.
@@ -276,7 +296,8 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
     model_selection.train_test_split(
       features_all, labels_all, test_size=test_fraction, stratify=labels_all)
 
-  # Populate fields of privacy report metadata
+  # Populate accuracy, loss fields in privacy report metadata
+  privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
   privacy_report_metadata.loss_train = loss_train
   privacy_report_metadata.loss_test = loss_test
   privacy_report_metadata.accuracy_train = accuracy_train
@@ -310,8 +331,7 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
   attacker = models.LogisticRegressionAttacker()
 
   # Create attacker data and populate fields of privacy_report_metadata
-  if privacy_report_metadata is None:
-    privacy_report_metadata = PrivacyReportMetadata()
+  privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
   prepared_attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input,
                                                         balance=balance_attacker_training,
                                                         privacy_report_metadata=privacy_report_metadata)
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia_test.py b/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia_test.py
index 295b457..650f696 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia_test.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/seq2seq_mia_test.py
@@ -162,7 +162,6 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
     self.assertIsNotNone(privacy_report_metadata.accuracy_train)
     self.assertIsNotNone(privacy_report_metadata.accuracy_test)
 
-
   def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
     attack_input = Seq2SeqAttackInputData(
         logits_train=iter([
@@ -345,8 +344,8 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
             seed=12345),
         balance_attacker_training=False)
     metadata = result.privacy_report_metadata
-    np.testing.assert_almost_equal(metadata.loss_train, 1.11, decimal=2)
-    np.testing.assert_almost_equal(metadata.loss_test, 1.10, decimal=2)
+    np.testing.assert_almost_equal(metadata.loss_train, 2.08, decimal=2)
+    np.testing.assert_almost_equal(metadata.loss_test, 2.02, decimal=2)
     np.testing.assert_almost_equal(metadata.accuracy_train, 0.40, decimal=2)
     np.testing.assert_almost_equal(metadata.accuracy_test, 0.34, decimal=2)
 