Return loss, accuracy instead of updating args

This commit is contained in:
amad-person 2020-11-27 11:59:06 +08:00
parent eb215072bc
commit 981d5a95f5
2 changed files with 53 additions and 34 deletions

View file

@ -25,7 +25,9 @@ from dataclasses import dataclass
from scipy.stats import rankdata
from sklearn import metrics, model_selection
from tensorflow.keras.metrics import SparseCategoricalCrossentropy, SparseCategoricalAccuracy
import tensorflow.keras.backend as K
from tensorflow.keras.losses import sparse_categorical_crossentropy
from tensorflow.keras.metrics import sparse_categorical_accuracy
from tensorflow_privacy.privacy.membership_inference_attack import models
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata, \
@ -95,7 +97,7 @@ class Seq2SeqAttackInputData:
_is_iterator(self.labels_test, 'labels_test')
def __str__(self):
"""Return the shapes of variables that are not None."""
"""Returns the shapes of variables that are not None."""
result = ['AttackInputData(']
if self.vocab_size is not None and self.train_size is not None:
@ -119,7 +121,7 @@ class Seq2SeqAttackInputData:
def _get_attack_features_and_metadata(logits: Iterator[np.ndarray],
labels: Iterator[np.ndarray]) -> (np.ndarray, float):
labels: Iterator[np.ndarray]) -> (np.ndarray, float, float):
"""Returns the average rank of tokens per batch of sequences,
and the loss computed using logits and labels.
@ -137,20 +139,30 @@ def _get_attack_features_and_metadata(logits: Iterator[np.ndarray],
3. Accuracy computed over all logits and labels.
"""
ranks = []
loss = SparseCategoricalCrossentropy(from_logits=True)
accuracy = SparseCategoricalAccuracy()
loss = 0.0
dataset_length = 0.0
correct_preds = 0
total_preds = 0
for batch_logits, batch_labels in zip(logits, labels):
# Compute average rank for the current batch.
batch_ranks = _get_batch_ranks(batch_logits, batch_labels)
ranks.append(np.mean(batch_ranks))
# Update overall loss with loss of the current batch.
_update_batch_loss(batch_logits, batch_labels, loss)
# Update overall loss metrics with metrics of the current batch.
batch_loss, batch_length = _get_batch_loss_metrics(batch_logits, batch_labels)
loss += batch_loss
dataset_length += batch_length
# Update overall accuracy with accuracy of the current batch.
_update_batch_accuracy(batch_logits, batch_labels, accuracy)
# Update overall accuracy metrics with metrics of the current batch.
batch_correct_preds, batch_total_preds = _get_batch_accuracy_metrics(batch_logits, batch_labels)
correct_preds += batch_correct_preds
total_preds += batch_total_preds
return np.array(ranks), loss.result().numpy(), accuracy.result().numpy()
# Compute loss and accuracy for the dataset.
loss = loss / dataset_length
accuracy = correct_preds / total_preds
return np.array(ranks), loss, accuracy
def _get_batch_ranks(batch_logits: np.ndarray,
@ -193,45 +205,53 @@ def _get_ranks_for_sequence(logits: np.ndarray,
return sequence_ranks
def _update_batch_loss(batch_logits: np.ndarray,
batch_labels: np.ndarray,
loss: SparseCategoricalCrossentropy):
"""Updates the loss metric per batch.
def _get_batch_loss_metrics(batch_logits: np.ndarray,
batch_labels: np.ndarray) -> (float, int):
"""Returns the loss, number of sequences for a batch.
Args:
batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
num_tokens, vocab_size).
batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
num_tokens, 1).
loss: SparseCategoricalCrossentropy loss metric.
"""
batch_loss = 0.0
batch_length = len(batch_logits)
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
loss.update_state(sequence_labels.astype(np.float32),
sequence_logits.astype(np.float32))
sequence_loss = sparse_categorical_crossentropy(K.constant(sequence_labels),
K.constant(sequence_logits),
from_logits=True)
batch_loss += sequence_loss.numpy().sum()
return batch_loss / batch_length, batch_length
def _update_batch_accuracy(batch_logits: np.ndarray,
batch_labels: np.ndarray,
accuracy: SparseCategoricalAccuracy):
"""Updates the accuracy metric per batch.
def _get_batch_accuracy_metrics(batch_logits: np.ndarray,
batch_labels: np.ndarray) -> (float, float):
"""Returns the number of correct predictions, total number of predictions for a batch.
Args:
batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
num_tokens, vocab_size).
batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
num_tokens, 1).
accuracy: SparseCategoricalAccuracy accuracy metric.
"""
batch_correct_preds = 0.0
batch_total_preds = 0.0
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
accuracy.update_state(sequence_labels.astype(np.float32),
sequence_logits.astype(np.float32))
preds = sparse_categorical_accuracy(K.constant(sequence_labels),
K.constant(sequence_logits))
batch_correct_preds += preds.numpy().sum()
batch_total_preds += len(sequence_labels)
return batch_correct_preds, batch_total_preds
def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
privacy_report_metadata: PrivacyReportMetadata,
test_fraction: float = 0.25,
balance: bool = True,
privacy_report_metadata: PrivacyReportMetadata = None) -> AttackerData:
"""Prepare Seq2SeqAttackInputData to train ML attackers.
balance: bool = True) -> AttackerData:
"""Prepares Seq2SeqAttackInputData to train ML attackers.
Uses logits and losses to generate ranks and performs a random train-test
split.
@ -241,11 +261,11 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
Args:
attack_input_data: Original Seq2SeqAttackInputData
privacy_report_metadata: the metadata of the model under attack.
test_fraction: Fraction of the dataset to include in the test split.
balance: Whether the training and test sets for the membership inference
attacker should have a balanced (roughly equal) number of samples from the
training and test sets used to develop the model under attack.
privacy_report_metadata: the metadata of the model under attack.
Returns:
AttackerData.
@ -276,7 +296,8 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
model_selection.train_test_split(
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
# Populate fields of privacy report metadata
# Populate accuracy, loss fields in privacy report metadata
privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
privacy_report_metadata.loss_train = loss_train
privacy_report_metadata.loss_test = loss_test
privacy_report_metadata.accuracy_train = accuracy_train
@ -310,8 +331,7 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
attacker = models.LogisticRegressionAttacker()
# Create attacker data and populate fields of privacy_report_metadata
if privacy_report_metadata is None:
privacy_report_metadata = PrivacyReportMetadata()
privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
prepared_attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input,
balance=balance_attacker_training,
privacy_report_metadata=privacy_report_metadata)

View file

@ -162,7 +162,6 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
attack_input = Seq2SeqAttackInputData(
logits_train=iter([
@ -345,8 +344,8 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
seed=12345),
balance_attacker_training=False)
metadata = result.privacy_report_metadata
np.testing.assert_almost_equal(metadata.loss_train, 1.11, decimal=2)
np.testing.assert_almost_equal(metadata.loss_test, 1.10, decimal=2)
np.testing.assert_almost_equal(metadata.loss_train, 2.08, decimal=2)
np.testing.assert_almost_equal(metadata.loss_test, 2.02, decimal=2)
np.testing.assert_almost_equal(metadata.accuracy_train, 0.40, decimal=2)
np.testing.assert_almost_equal(metadata.accuracy_test, 0.34, decimal=2)