Return loss, accuracy instead of updating args

parent eb215072bc
commit 981d5a95f5
2 changed files with 53 additions and 34 deletions
@@ -25,7 +25,9 @@ from dataclasses import dataclass
 from scipy.stats import rankdata
 from sklearn import metrics, model_selection
 
-from tensorflow.keras.metrics import SparseCategoricalCrossentropy, SparseCategoricalAccuracy
+import tensorflow.keras.backend as K
+from tensorflow.keras.losses import sparse_categorical_crossentropy
+from tensorflow.keras.metrics import sparse_categorical_accuracy
 
 from tensorflow_privacy.privacy.membership_inference_attack import models
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata, \
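The import swap above replaces the stateful Keras metric classes with the functional loss/metric callables plus the Keras backend. A minimal sketch of the functional call pattern the new code relies on (toy logits and labels invented for illustration; assumes TF 2.x eager execution):

```python
import numpy as np
import tensorflow.keras.backend as K
from tensorflow.keras.losses import sparse_categorical_crossentropy
from tensorflow.keras.metrics import sparse_categorical_accuracy

# One toy sequence: 3 tokens, vocab size 3.
labels = np.array([2, 0, 1], dtype=np.float32)
logits = np.array([[0.1, 0.2, 3.0],   # argmax 2 -> correct
                   [2.5, 0.3, 0.1],   # argmax 0 -> correct
                   [0.2, 0.1, 1.7]],  # argmax 2 -> wrong
                  dtype=np.float32)

# Per-token cross-entropy; from_logits=True applies softmax internally,
# matching the raw seq2seq logits handled in this file.
token_losses = sparse_categorical_crossentropy(
    K.constant(labels), K.constant(logits), from_logits=True)
print(token_losses.numpy().shape)  # (3,) - one loss value per token

# Per-token 0/1 hits; summing yields the number of correct predictions.
token_hits = sparse_categorical_accuracy(K.constant(labels), K.constant(logits))
print(token_hits.numpy().sum())    # 2.0 of 3 tokens correct
```

Unlike the `SparseCategoricalCrossentropy` / `SparseCategoricalAccuracy` objects, these functions keep no running state, which is what lets the helpers below return plain numbers.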
@@ -95,7 +97,7 @@ class Seq2SeqAttackInputData:
     _is_iterator(self.labels_test, 'labels_test')
 
   def __str__(self):
-    """Return the shapes of variables that are not None."""
+    """Returns the shapes of variables that are not None."""
     result = ['AttackInputData(']
 
     if self.vocab_size is not None and self.train_size is not None:
@@ -119,7 +121,7 @@ class Seq2SeqAttackInputData:
 
 
 def _get_attack_features_and_metadata(logits: Iterator[np.ndarray],
-                                      labels: Iterator[np.ndarray]) -> (np.ndarray, float):
+                                      labels: Iterator[np.ndarray]) -> (np.ndarray, float, float):
   """Returns the average rank of tokens per batch of sequences,
   and the loss computed using logits and labels.
 
@@ -137,20 +139,30 @@ def _get_attack_features_and_metadata(logits: Iterator[np.ndarray],
     3. Accuracy computed over all logits and labels.
   """
   ranks = []
-  loss = SparseCategoricalCrossentropy(from_logits=True)
-  accuracy = SparseCategoricalAccuracy()
+  loss = 0.0
+  dataset_length = 0.0
+  correct_preds = 0
+  total_preds = 0
  for batch_logits, batch_labels in zip(logits, labels):
     # Compute average rank for the current batch.
     batch_ranks = _get_batch_ranks(batch_logits, batch_labels)
     ranks.append(np.mean(batch_ranks))
 
-    # Update overall loss with loss of the current batch.
-    _update_batch_loss(batch_logits, batch_labels, loss)
+    # Update overall loss metrics with metrics of the current batch.
+    batch_loss, batch_length = _get_batch_loss_metrics(batch_logits, batch_labels)
+    loss += batch_loss
+    dataset_length += batch_length
 
-    # Update overall accuracy with accuracy of the current batch.
-    _update_batch_accuracy(batch_logits, batch_labels, accuracy)
+    # Update overall accuracy metrics with metrics of the current batch.
+    batch_correct_preds, batch_total_preds = _get_batch_accuracy_metrics(batch_logits, batch_labels)
+    correct_preds += batch_correct_preds
+    total_preds += batch_total_preds
 
-  return np.array(ranks), loss.result().numpy(), accuracy.result().numpy()
+  # Compute loss and accuracy for the dataset.
+  loss = loss / dataset_length
+  accuracy = correct_preds / total_preds
+
+  return np.array(ranks), loss, accuracy
 
 
 def _get_batch_ranks(batch_logits: np.ndarray,
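With the metric objects gone, `_get_attack_features_and_metadata` aggregates plain Python numbers and normalizes once at the end. A condensed sketch of the new control flow, with hypothetical per-batch values standing in for the helper outputs:

```python
# Hypothetical (batch_loss, batch_length, batch_correct, batch_total) tuples,
# i.e. what _get_batch_loss_metrics / _get_batch_accuracy_metrics would return.
batches = [(2.1, 4, 9.0, 12.0),
           (1.9, 4, 7.0, 12.0)]

loss, dataset_length = 0.0, 0.0
correct_preds, total_preds = 0, 0
for batch_loss, batch_length, batch_correct, batch_total in batches:
  loss += batch_loss
  dataset_length += batch_length
  correct_preds += batch_correct
  total_preds += batch_total

loss = loss / dataset_length            # (2.1 + 1.9) / 8 = 0.5
accuracy = correct_preds / total_preds  # 16.0 / 24.0, roughly 0.667
```

Returning `(ranks, loss, accuracy)` as a plain 3-tuple is what the widened annotation `-> (np.ndarray, float, float)` in the signature hunk above reflects.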
@@ -193,37 +205,45 @@ def _get_ranks_for_sequence(logits: np.ndarray,
   return sequence_ranks
 
 
-def _update_batch_loss(batch_logits: np.ndarray,
-                       batch_labels: np.ndarray,
-                       loss: SparseCategoricalCrossentropy):
-  """Updates the loss metric per batch.
+def _get_batch_loss_metrics(batch_logits: np.ndarray,
+                            batch_labels: np.ndarray) -> (float, int):
+  """Returns the loss, number of sequences for a batch.
 
   Args:
     batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
       num_tokens, vocab_size).
     batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
       num_tokens, 1).
-    loss: SparseCategoricalCrossentropy loss metric.
   """
+  batch_loss = 0.0
+  batch_length = len(batch_logits)
   for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
-    loss.update_state(sequence_labels.astype(np.float32),
-                      sequence_logits.astype(np.float32))
+    sequence_loss = sparse_categorical_crossentropy(K.constant(sequence_labels),
+                                                    K.constant(sequence_logits),
+                                                    from_logits=True)
+    batch_loss += sequence_loss.numpy().sum()
+
+  return batch_loss / batch_length, batch_length
 
 
-def _update_batch_accuracy(batch_logits: np.ndarray,
-                           batch_labels: np.ndarray,
-                           accuracy: SparseCategoricalAccuracy):
-  """Updates the accuracy metric per batch.
+def _get_batch_accuracy_metrics(batch_logits: np.ndarray,
+                                batch_labels: np.ndarray) -> (float, float):
+  """Returns the number of correct predictions, total number of predictions for a batch.
 
   Args:
     batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
       num_tokens, vocab_size).
     batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
       num_tokens, 1).
-    accuracy: SparseCategoricalAccuracy accuracy metric.
   """
+  batch_correct_preds = 0.0
+  batch_total_preds = 0.0
   for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
-    accuracy.update_state(sequence_labels.astype(np.float32),
-                          sequence_logits.astype(np.float32))
+    preds = sparse_categorical_accuracy(K.constant(sequence_labels),
+                                        K.constant(sequence_logits))
+    batch_correct_preds += preds.numpy().sum()
+    batch_total_preds += len(sequence_labels)
+
+  return batch_correct_preds, batch_total_preds
 
 
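The rewritten helpers are pure functions over one batch. A worked illustration of what a single loop iteration of `_get_batch_loss_metrics` computes for one 3-token sequence (numbers chosen so the expected loss can be read off by hand):

```python
import numpy as np
import tensorflow.keras.backend as K
from tensorflow.keras.losses import sparse_categorical_crossentropy

# Logits built as log-probabilities, so softmax recovers the rows below.
sequence_labels = np.array([1, 0, 2], dtype=np.float32)
sequence_logits = np.log(np.array([[0.1, 0.8, 0.1],
                                   [0.7, 0.2, 0.1],
                                   [0.2, 0.2, 0.6]], dtype=np.float32))

token_losses = sparse_categorical_crossentropy(
    K.constant(sequence_labels), K.constant(sequence_logits),
    from_logits=True)
# -ln(0.8) - ln(0.7) - ln(0.6) is roughly 0.223 + 0.357 + 0.511 = 1.09
print(token_losses.numpy().sum())
```

`_get_batch_accuracy_metrics` follows the same shape: `sparse_categorical_accuracy` yields one 0/1 value per token, so `preds.numpy().sum()` counts correct tokens while `len(sequence_labels)` contributes the token total.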
@@ -230,8 +250,8 @@ def _update_batch_accuracy(batch_logits: np.ndarray,
 def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
+                                 privacy_report_metadata: PrivacyReportMetadata,
                                  test_fraction: float = 0.25,
-                                 balance: bool = True,
-                                 privacy_report_metadata: PrivacyReportMetadata = None) -> AttackerData:
-  """Prepare Seq2SeqAttackInputData to train ML attackers.
+                                 balance: bool = True) -> AttackerData:
+  """Prepares Seq2SeqAttackInputData to train ML attackers.
 
   Uses logits and losses to generate ranks and performs a random train-test
   split.
@@ -241,11 +261,11 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
 
   Args:
     attack_input_data: Original Seq2SeqAttackInputData
+    privacy_report_metadata: the metadata of the model under attack.
     test_fraction: Fraction of the dataset to include in the test split.
     balance: Whether the training and test sets for the membership inference
       attacker should have a balanced (roughly equal) number of samples from the
      training and test sets used to develop the model under attack.
-    privacy_report_metadata: the metadata of the model under attack.
 
   Returns:
     AttackerData.
@@ -276,7 +296,8 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
     model_selection.train_test_split(
         features_all, labels_all, test_size=test_fraction, stratify=labels_all)
 
-  # Populate fields of privacy report metadata
+  # Populate accuracy, loss fields in privacy report metadata
+  privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
   privacy_report_metadata.loss_train = loss_train
   privacy_report_metadata.loss_test = loss_test
   privacy_report_metadata.accuracy_train = accuracy_train
@@ -310,8 +331,7 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
   attacker = models.LogisticRegressionAttacker()
 
   # Create attacker data and populate fields of privacy_report_metadata
-  if privacy_report_metadata is None:
-    privacy_report_metadata = PrivacyReportMetadata()
+  privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
   prepared_attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input,
                                                         balance=balance_attacker_training,
                                                         privacy_report_metadata=privacy_report_metadata)
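Both call sites now share the same one-line default: `x = x or Default()` replaces the explicit `None` check. A minimal sketch of the idiom, using a hypothetical stub class in place of `PrivacyReportMetadata`:

```python
class MetadataStub:  # stand-in for PrivacyReportMetadata
  loss_train = None

def attach_loss(metadata=None):
  # Any falsy metadata (here, None) is swapped for a fresh instance.
  metadata = metadata or MetadataStub()
  metadata.loss_train = 2.08
  return metadata

print(attach_loss().loss_train)   # 2.08, metadata auto-created
stub = MetadataStub()
print(attach_loss(stub) is stub)  # True, caller's object is reused
```

The guard inside `create_seq2seq_attacker_data` means the function no longer assumes the caller passed metadata, even though the parameter is now positional. The remaining two hunks are in the second changed file, the seq2seq attack tests.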
@@ -162,7 +162,6 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
     self.assertIsNotNone(privacy_report_metadata.accuracy_train)
     self.assertIsNotNone(privacy_report_metadata.accuracy_test)
 
-
   def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
     attack_input = Seq2SeqAttackInputData(
         logits_train=iter([
@@ -345,8 +344,8 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
             seed=12345),
         balance_attacker_training=False)
     metadata = result.privacy_report_metadata
-    np.testing.assert_almost_equal(metadata.loss_train, 1.11, decimal=2)
-    np.testing.assert_almost_equal(metadata.loss_test, 1.10, decimal=2)
+    np.testing.assert_almost_equal(metadata.loss_train, 2.08, decimal=2)
+    np.testing.assert_almost_equal(metadata.loss_test, 2.02, decimal=2)
     np.testing.assert_almost_equal(metadata.accuracy_train, 0.40, decimal=2)
     np.testing.assert_almost_equal(metadata.accuracy_test, 0.34, decimal=2)
 
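The expected accuracies are unchanged, since both the old metric object and the new helper count correct tokens the same way. The expected losses move (1.11 to 2.08 train, 1.10 to 2.02 test), presumably because the reduction changed: the Keras metric kept a running mean over individual tokens, while the new code sums token losses per sequence and then normalizes by sequence counts. A toy illustration of how the two reductions diverge (numbers hypothetical, not taken from the test data):

```python
# Two sequences' per-token losses (hypothetical).
seq_token_losses = [[0.2, 0.3], [0.5, 0.4, 0.6]]

# Old reduction: running mean over every token, as the Keras
# SparseCategoricalCrossentropy metric computed it.
token_mean = sum(sum(s) for s in seq_token_losses) / 5  # 2.0 / 5 = 0.40

# New reduction: sum tokens per sequence, average per sequence in the
# helper, then divide by the dataset length again in the caller.
batch_avg = sum(sum(s) for s in seq_token_losses) / 2   # 2.0 / 2 = 1.00
dataset_loss = batch_avg / 2                            # 1.00 / 2 = 0.50

print(token_mean, dataset_loss)  # 0.4 vs. 0.5: different expected values
```

Hence the recomputed constants in the assertions above.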