forked from 626_privacy/tensorflow_privacy
Return loss, accuracy instead of updating args
This commit is contained in:
parent
eb215072bc
commit
981d5a95f5
2 changed files with 53 additions and 34 deletions
|
@ -25,7 +25,9 @@ from dataclasses import dataclass
|
|||
from scipy.stats import rankdata
|
||||
from sklearn import metrics, model_selection
|
||||
|
||||
from tensorflow.keras.metrics import SparseCategoricalCrossentropy, SparseCategoricalAccuracy
|
||||
import tensorflow.keras.backend as K
|
||||
from tensorflow.keras.losses import sparse_categorical_crossentropy
|
||||
from tensorflow.keras.metrics import sparse_categorical_accuracy
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack import models
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata, \
|
||||
|
@ -95,7 +97,7 @@ class Seq2SeqAttackInputData:
|
|||
_is_iterator(self.labels_test, 'labels_test')
|
||||
|
||||
def __str__(self):
|
||||
"""Return the shapes of variables that are not None."""
|
||||
"""Returns the shapes of variables that are not None."""
|
||||
result = ['AttackInputData(']
|
||||
|
||||
if self.vocab_size is not None and self.train_size is not None:
|
||||
|
@ -119,7 +121,7 @@ class Seq2SeqAttackInputData:
|
|||
|
||||
|
||||
def _get_attack_features_and_metadata(logits: Iterator[np.ndarray],
|
||||
labels: Iterator[np.ndarray]) -> (np.ndarray, float):
|
||||
labels: Iterator[np.ndarray]) -> (np.ndarray, float, float):
|
||||
"""Returns the average rank of tokens per batch of sequences,
|
||||
and the loss computed using logits and labels.
|
||||
|
||||
|
@ -137,20 +139,30 @@ def _get_attack_features_and_metadata(logits: Iterator[np.ndarray],
|
|||
3. Accuracy computed over all logits and labels.
|
||||
"""
|
||||
ranks = []
|
||||
loss = SparseCategoricalCrossentropy(from_logits=True)
|
||||
accuracy = SparseCategoricalAccuracy()
|
||||
loss = 0.0
|
||||
dataset_length = 0.0
|
||||
correct_preds = 0
|
||||
total_preds = 0
|
||||
for batch_logits, batch_labels in zip(logits, labels):
|
||||
# Compute average rank for the current batch.
|
||||
batch_ranks = _get_batch_ranks(batch_logits, batch_labels)
|
||||
ranks.append(np.mean(batch_ranks))
|
||||
|
||||
# Update overall loss with loss of the current batch.
|
||||
_update_batch_loss(batch_logits, batch_labels, loss)
|
||||
# Update overall loss metrics with metrics of the current batch.
|
||||
batch_loss, batch_length = _get_batch_loss_metrics(batch_logits, batch_labels)
|
||||
loss += batch_loss
|
||||
dataset_length += batch_length
|
||||
|
||||
# Update overall accuracy with accuracy of the current batch.
|
||||
_update_batch_accuracy(batch_logits, batch_labels, accuracy)
|
||||
# Update overall accuracy metrics with metrics of the current batch.
|
||||
batch_correct_preds, batch_total_preds = _get_batch_accuracy_metrics(batch_logits, batch_labels)
|
||||
correct_preds += batch_correct_preds
|
||||
total_preds += batch_total_preds
|
||||
|
||||
return np.array(ranks), loss.result().numpy(), accuracy.result().numpy()
|
||||
# Compute loss and accuracy for the dataset.
|
||||
loss = loss / dataset_length
|
||||
accuracy = correct_preds / total_preds
|
||||
|
||||
return np.array(ranks), loss, accuracy
|
||||
|
||||
|
||||
def _get_batch_ranks(batch_logits: np.ndarray,
|
||||
|
@ -193,45 +205,53 @@ def _get_ranks_for_sequence(logits: np.ndarray,
|
|||
return sequence_ranks
|
||||
|
||||
|
||||
def _update_batch_loss(batch_logits: np.ndarray,
|
||||
batch_labels: np.ndarray,
|
||||
loss: SparseCategoricalCrossentropy):
|
||||
"""Updates the loss metric per batch.
|
||||
def _get_batch_loss_metrics(batch_logits: np.ndarray,
|
||||
batch_labels: np.ndarray) -> (float, int):
|
||||
"""Returns the loss, number of sequences for a batch.
|
||||
|
||||
Args:
|
||||
batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
|
||||
num_tokens, vocab_size).
|
||||
batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
|
||||
num_tokens, 1).
|
||||
loss: SparseCategoricalCrossentropy loss metric.
|
||||
"""
|
||||
batch_loss = 0.0
|
||||
batch_length = len(batch_logits)
|
||||
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
|
||||
loss.update_state(sequence_labels.astype(np.float32),
|
||||
sequence_logits.astype(np.float32))
|
||||
sequence_loss = sparse_categorical_crossentropy(K.constant(sequence_labels),
|
||||
K.constant(sequence_logits),
|
||||
from_logits=True)
|
||||
batch_loss += sequence_loss.numpy().sum()
|
||||
|
||||
return batch_loss / batch_length, batch_length
|
||||
|
||||
|
||||
def _update_batch_accuracy(batch_logits: np.ndarray,
|
||||
batch_labels: np.ndarray,
|
||||
accuracy: SparseCategoricalAccuracy):
|
||||
"""Updates the accuracy metric per batch.
|
||||
def _get_batch_accuracy_metrics(batch_logits: np.ndarray,
|
||||
batch_labels: np.ndarray) -> (float, float):
|
||||
"""Returns the number of correct predictions, total number of predictions for a batch.
|
||||
|
||||
Args:
|
||||
batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
|
||||
num_tokens, vocab_size).
|
||||
batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
|
||||
num_tokens, 1).
|
||||
accuracy: SparseCategoricalAccuracy accuracy metric.
|
||||
"""
|
||||
batch_correct_preds = 0.0
|
||||
batch_total_preds = 0.0
|
||||
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
|
||||
accuracy.update_state(sequence_labels.astype(np.float32),
|
||||
sequence_logits.astype(np.float32))
|
||||
preds = sparse_categorical_accuracy(K.constant(sequence_labels),
|
||||
K.constant(sequence_logits))
|
||||
batch_correct_preds += preds.numpy().sum()
|
||||
batch_total_preds += len(sequence_labels)
|
||||
|
||||
return batch_correct_preds, batch_total_preds
|
||||
|
||||
|
||||
def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
|
||||
privacy_report_metadata: PrivacyReportMetadata,
|
||||
test_fraction: float = 0.25,
|
||||
balance: bool = True,
|
||||
privacy_report_metadata: PrivacyReportMetadata = None) -> AttackerData:
|
||||
"""Prepare Seq2SeqAttackInputData to train ML attackers.
|
||||
balance: bool = True) -> AttackerData:
|
||||
"""Prepares Seq2SeqAttackInputData to train ML attackers.
|
||||
|
||||
Uses logits and losses to generate ranks and performs a random train-test
|
||||
split.
|
||||
|
@ -241,11 +261,11 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
|
|||
|
||||
Args:
|
||||
attack_input_data: Original Seq2SeqAttackInputData
|
||||
privacy_report_metadata: the metadata of the model under attack.
|
||||
test_fraction: Fraction of the dataset to include in the test split.
|
||||
balance: Whether the training and test sets for the membership inference
|
||||
attacker should have a balanced (roughly equal) number of samples from the
|
||||
training and test sets used to develop the model under attack.
|
||||
privacy_report_metadata: the metadata of the model under attack.
|
||||
|
||||
Returns:
|
||||
AttackerData.
|
||||
|
@ -276,7 +296,8 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
|
|||
model_selection.train_test_split(
|
||||
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
|
||||
|
||||
# Populate fields of privacy report metadata
|
||||
# Populate accuracy, loss fields in privacy report metadata
|
||||
privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
|
||||
privacy_report_metadata.loss_train = loss_train
|
||||
privacy_report_metadata.loss_test = loss_test
|
||||
privacy_report_metadata.accuracy_train = accuracy_train
|
||||
|
@ -310,8 +331,7 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
|||
attacker = models.LogisticRegressionAttacker()
|
||||
|
||||
# Create attacker data and populate fields of privacy_report_metadata
|
||||
if privacy_report_metadata is None:
|
||||
privacy_report_metadata = PrivacyReportMetadata()
|
||||
privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
|
||||
prepared_attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input,
|
||||
balance=balance_attacker_training,
|
||||
privacy_report_metadata=privacy_report_metadata)
|
||||
|
|
|
@ -162,7 +162,6 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
|
|||
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
|
||||
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
|
||||
|
||||
|
||||
def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
|
||||
attack_input = Seq2SeqAttackInputData(
|
||||
logits_train=iter([
|
||||
|
@ -345,8 +344,8 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
|
|||
seed=12345),
|
||||
balance_attacker_training=False)
|
||||
metadata = result.privacy_report_metadata
|
||||
np.testing.assert_almost_equal(metadata.loss_train, 1.11, decimal=2)
|
||||
np.testing.assert_almost_equal(metadata.loss_test, 1.10, decimal=2)
|
||||
np.testing.assert_almost_equal(metadata.loss_train, 2.08, decimal=2)
|
||||
np.testing.assert_almost_equal(metadata.loss_test, 2.02, decimal=2)
|
||||
np.testing.assert_almost_equal(metadata.accuracy_train, 0.40, decimal=2)
|
||||
np.testing.assert_almost_equal(metadata.accuracy_test, 0.34, decimal=2)
|
||||
|
||||
|
|
Loading…
Reference in a new issue