Return loss, accuracy instead of updating args

amad-person 2020-11-27 11:59:06 +08:00
parent eb215072bc
commit 981d5a95f5
2 changed files with 53 additions and 34 deletions


@@ -25,7 +25,9 @@ from dataclasses import dataclass
 from scipy.stats import rankdata
 from sklearn import metrics, model_selection
-from tensorflow.keras.metrics import SparseCategoricalCrossentropy, SparseCategoricalAccuracy
+import tensorflow.keras.backend as K
+from tensorflow.keras.losses import sparse_categorical_crossentropy
+from tensorflow.keras.metrics import sparse_categorical_accuracy
 from tensorflow_privacy.privacy.membership_inference_attack import models
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata, \
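
For reference, the stateful Keras metric objects removed above accumulate results across update_state() calls, while the functional ops imported in their place return per-token values directly. A minimal standalone sketch (not part of this commit), assuming TF 2.x eager execution and illustrative shapes:

    import numpy as np
    import tensorflow.keras.backend as K
    from tensorflow.keras.losses import sparse_categorical_crossentropy
    from tensorflow.keras.metrics import sparse_categorical_accuracy

    # One sequence: 3 tokens, vocabulary of 4.
    labels = np.array([[1.0], [0.0], [2.0]])            # (num_tokens, 1)
    logits = np.random.randn(3, 4).astype(np.float32)   # (num_tokens, vocab_size)

    # Per-token cross-entropy and 0/1 accuracy values; no metric state to reset.
    token_losses = sparse_categorical_crossentropy(
        K.constant(labels), K.constant(logits), from_logits=True)
    token_hits = sparse_categorical_accuracy(K.constant(labels), K.constant(logits))
    print(token_losses.numpy().sum(), token_hits.numpy().mean())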
@@ -95,7 +97,7 @@ class Seq2SeqAttackInputData:
     _is_iterator(self.labels_test, 'labels_test')

   def __str__(self):
-    """Return the shapes of variables that are not None."""
+    """Returns the shapes of variables that are not None."""
     result = ['AttackInputData(']
     if self.vocab_size is not None and self.train_size is not None:
@@ -119,7 +121,7 @@ class Seq2SeqAttackInputData:
 def _get_attack_features_and_metadata(logits: Iterator[np.ndarray],
-                                      labels: Iterator[np.ndarray]) -> (np.ndarray, float):
+                                      labels: Iterator[np.ndarray]) -> (np.ndarray, float, float):
   """Returns the average rank of tokens per batch of sequences,
   and the loss computed using logits and labels.
@@ -137,20 +139,30 @@ def _get_attack_features_and_metadata(logits: Iterator[np.ndarray],
     3. Accuracy computed over all logits and labels.
   """
   ranks = []
-  loss = SparseCategoricalCrossentropy(from_logits=True)
-  accuracy = SparseCategoricalAccuracy()
+  loss = 0.0
+  dataset_length = 0.0
+  correct_preds = 0
+  total_preds = 0
   for batch_logits, batch_labels in zip(logits, labels):
     # Compute average rank for the current batch.
     batch_ranks = _get_batch_ranks(batch_logits, batch_labels)
     ranks.append(np.mean(batch_ranks))

-    # Update overall loss with loss of the current batch.
-    _update_batch_loss(batch_logits, batch_labels, loss)
+    # Update overall loss metrics with metrics of the current batch.
+    batch_loss, batch_length = _get_batch_loss_metrics(batch_logits, batch_labels)
+    loss += batch_loss
+    dataset_length += batch_length

-    # Update overall accuracy with accuracy of the current batch.
-    _update_batch_accuracy(batch_logits, batch_labels, accuracy)
+    # Update overall accuracy metrics with metrics of the current batch.
+    batch_correct_preds, batch_total_preds = _get_batch_accuracy_metrics(batch_logits, batch_labels)
+    correct_preds += batch_correct_preds
+    total_preds += batch_total_preds

-  return np.array(ranks), loss.result().numpy(), accuracy.result().numpy()
+  # Compute loss and accuracy for the dataset.
+  loss = loss / dataset_length
+  accuracy = correct_preds / total_preds
+
+  return np.array(ranks), loss, accuracy


 def _get_batch_ranks(batch_logits: np.ndarray,
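
The dataset-level loss and accuracy in the hunk above are folded together from per-batch values. A plain-Python sketch of that aggregation (hypothetical helper name, not part of this commit):

    def _aggregate_dataset_metrics(batch_metrics):
      """batch_metrics: iterable of (batch_loss, batch_length, correct_preds, total_preds)."""
      loss, dataset_length = 0.0, 0.0
      correct_preds, total_preds = 0.0, 0.0
      for batch_loss, batch_length, correct, total in batch_metrics:
        loss += batch_loss            # per-batch value from _get_batch_loss_metrics
        dataset_length += batch_length
        correct_preds += correct      # per-batch counts from _get_batch_accuracy_metrics
        total_preds += total
      return loss / dataset_length, correct_preds / total_preds

    # Two batches of 4 sequences each:
    print(_aggregate_dataset_metrics([(2.1, 4, 10, 16), (1.9, 4, 12, 16)]))  # (0.5, 0.6875)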
@@ -193,45 +205,53 @@ def _get_ranks_for_sequence(logits: np.ndarray,
   return sequence_ranks


-def _update_batch_loss(batch_logits: np.ndarray,
-                       batch_labels: np.ndarray,
-                       loss: SparseCategoricalCrossentropy):
-  """Updates the loss metric per batch.
+def _get_batch_loss_metrics(batch_logits: np.ndarray,
+                            batch_labels: np.ndarray) -> (float, int):
+  """Returns the loss, number of sequences for a batch.

   Args:
     batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
       num_tokens, vocab_size).
     batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
       num_tokens, 1).
-    loss: SparseCategoricalCrossentropy loss metric.
   """
+  batch_loss = 0.0
+  batch_length = len(batch_logits)
   for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
-    loss.update_state(sequence_labels.astype(np.float32),
-                      sequence_logits.astype(np.float32))
+    sequence_loss = sparse_categorical_crossentropy(K.constant(sequence_labels),
+                                                    K.constant(sequence_logits),
+                                                    from_logits=True)
+    batch_loss += sequence_loss.numpy().sum()
+
+  return batch_loss / batch_length, batch_length


-def _update_batch_accuracy(batch_logits: np.ndarray,
-                           batch_labels: np.ndarray,
-                           accuracy: SparseCategoricalAccuracy):
-  """Updates the accuracy metric per batch.
+def _get_batch_accuracy_metrics(batch_logits: np.ndarray,
+                                batch_labels: np.ndarray) -> (float, float):
+  """Returns the number of correct predictions, total number of predictions for a batch.

   Args:
     batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
       num_tokens, vocab_size).
     batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
       num_tokens, 1).
-    accuracy: SparseCategoricalAccuracy accuracy metric.
   """
+  batch_correct_preds = 0.0
+  batch_total_preds = 0.0
   for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
-    accuracy.update_state(sequence_labels.astype(np.float32),
-                          sequence_logits.astype(np.float32))
+    preds = sparse_categorical_accuracy(K.constant(sequence_labels),
+                                        K.constant(sequence_logits))
+    batch_correct_preds += preds.numpy().sum()
+    batch_total_preds += len(sequence_labels)
+
+  return batch_correct_preds, batch_total_preds


 def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
+                                 privacy_report_metadata: PrivacyReportMetadata,
                                  test_fraction: float = 0.25,
-                                 balance: bool = True,
-                                 privacy_report_metadata: PrivacyReportMetadata = None) -> AttackerData:
-  """Prepare Seq2SeqAttackInputData to train ML attackers.
+                                 balance: bool = True) -> AttackerData:
+  """Prepares Seq2SeqAttackInputData to train ML attackers.

   Uses logits and losses to generate ranks and performs a random train-test
   split.
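
A usage sketch for the two new per-batch helpers introduced above (not in this diff; dummy data, shapes follow the docstrings, called from inside this module or its test):

    import numpy as np

    # A batch of 2 sequences, 5 tokens each, vocabulary of 8.
    batch_logits = np.random.randn(2, 5, 8).astype(np.float32)                 # (num_sequences, num_tokens, vocab_size)
    batch_labels = np.random.randint(0, 8, size=(2, 5, 1)).astype(np.float32)  # (num_sequences, num_tokens, 1)

    batch_loss, batch_length = _get_batch_loss_metrics(batch_logits, batch_labels)
    correct_preds, total_preds = _get_batch_accuracy_metrics(batch_logits, batch_labels)

    print(batch_loss)                    # summed token losses averaged over the 2 sequences
    print(correct_preds / total_preds)   # accuracy over the 10 token predictions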
@@ -241,11 +261,11 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
   Args:
     attack_input_data: Original Seq2SeqAttackInputData
+    privacy_report_metadata: the metadata of the model under attack.
     test_fraction: Fraction of the dataset to include in the test split.
     balance: Whether the training and test sets for the membership inference
       attacker should have a balanced (roughly equal) number of samples from the
       training and test sets used to develop the model under attack.
-    privacy_report_metadata: the metadata of the model under attack.

   Returns:
     AttackerData.
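
With privacy_report_metadata moved ahead of the optional arguments, a call now looks roughly like this (sketch with a hypothetical attack_input, not taken from this commit):

    attacker_data = create_seq2seq_attacker_data(
        attack_input_data=attack_input,                   # a populated Seq2SeqAttackInputData
        privacy_report_metadata=PrivacyReportMetadata(),  # loss/accuracy fields filled during preparation
        test_fraction=0.25,
        balance=True)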
@@ -276,7 +296,8 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
     model_selection.train_test_split(
         features_all, labels_all, test_size=test_fraction, stratify=labels_all)

-  # Populate fields of privacy report metadata
+  # Populate accuracy, loss fields in privacy report metadata
+  privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
   privacy_report_metadata.loss_train = loss_train
   privacy_report_metadata.loss_test = loss_test
   privacy_report_metadata.accuracy_train = accuracy_train
@@ -310,8 +331,7 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
   attacker = models.LogisticRegressionAttacker()

   # Create attacker data and populate fields of privacy_report_metadata
-  if privacy_report_metadata is None:
-    privacy_report_metadata = PrivacyReportMetadata()
+  privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
   prepared_attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input,
                                                         balance=balance_attacker_training,
                                                         privacy_report_metadata=privacy_report_metadata)
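
Both call sites above now default the metadata with `or`, so callers of run_seq2seq_attack can still omit it. A usage sketch (hypothetical attack_input, assuming only the keyword arguments visible in this diff and in the tests below):

    # Metadata created internally when omitted:
    result = run_seq2seq_attack(attack_input, balance_attacker_training=False)
    print(result.privacy_report_metadata.loss_train)

    # Or supplied explicitly and populated in place:
    metadata = PrivacyReportMetadata()
    result = run_seq2seq_attack(attack_input, privacy_report_metadata=metadata)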


@@ -162,7 +162,6 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
     self.assertIsNotNone(privacy_report_metadata.accuracy_train)
     self.assertIsNotNone(privacy_report_metadata.accuracy_test)
-
   def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
     attack_input = Seq2SeqAttackInputData(
         logits_train=iter([
@@ -345,8 +344,8 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
             seed=12345),
         balance_attacker_training=False)
     metadata = result.privacy_report_metadata
-    np.testing.assert_almost_equal(metadata.loss_train, 1.11, decimal=2)
-    np.testing.assert_almost_equal(metadata.loss_test, 1.10, decimal=2)
+    np.testing.assert_almost_equal(metadata.loss_train, 2.08, decimal=2)
+    np.testing.assert_almost_equal(metadata.loss_test, 2.02, decimal=2)
     np.testing.assert_almost_equal(metadata.accuracy_train, 0.40, decimal=2)
     np.testing.assert_almost_equal(metadata.accuracy_test, 0.34, decimal=2)