Compute and populate PrivacyReportMetadata fields

amad-person 2020-11-25 16:06:37 +08:00
parent 46bee91cda
commit eb215072bc
3 changed files with 146 additions and 25 deletions

View file

@@ -1163,7 +1163,10 @@
"fig = plotting.plot_roc_curve(attack_result.get_result_with_max_auc().roc_curve)\n",
"\n",
"# Print a user-friendly summary of the attacks\n",
"print(attack_result.summary())"
"print(attack_result.summary())\n",
"\n",
"# Print metadata of the target model\n",
"print(attack_result.privacy_report_metadata)"
]
}
],
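For reference, a minimal sketch (not part of the commit) of reading the individual fields that this change populates on the returned AttackResults:

# Assumes `attack_result` is the AttackResults returned by run_seq2seq_attack.
metadata = attack_result.privacy_report_metadata
print('train/test loss:', metadata.loss_train, metadata.loss_test)
print('train/test accuracy:', metadata.accuracy_train, metadata.accuracy_test)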

View file

@@ -25,9 +25,11 @@ from dataclasses import dataclass
from scipy.stats import rankdata
from sklearn import metrics, model_selection
from tensorflow.keras.metrics import SparseCategoricalCrossentropy, SparseCategoricalAccuracy
from tensorflow_privacy.privacy.membership_inference_attack import models
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata, AttackResults, \
RocCurve, SingleAttackResult, SingleSliceSpec, AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata, \
AttackResults, RocCurve, SingleAttackResult, SingleSliceSpec, AttackType
from tensorflow_privacy.privacy.membership_inference_attack.models import AttackerData, _sample_multidimensional_array
@@ -116,9 +118,10 @@ class Seq2SeqAttackInputData:
return '\n'.join(result)
def _get_average_ranks(logits: Iterator[np.ndarray],
labels: Iterator[np.ndarray]) -> np.ndarray:
"""Returns the average rank of tokens in a batch of sequences.
def _get_attack_features_and_metadata(logits: Iterator[np.ndarray],
labels: Iterator[np.ndarray]) -> (np.ndarray, float, float):
"""Returns the average rank of tokens per batch of sequences,
and the loss computed using logits and labels.
Args:
logits: Logits returned by a seq2seq model, dim = (num_batches,
@@ -127,18 +130,48 @@ def _get_average_ranks(logits: Iterator[np.ndarray],
num_sequences, num_tokens, 1).
Returns:
An array of average ranks, dim = (num_batches, 1).
1. An array of average ranks, dim = (num_batches, 1).
Each average rank is calculated over ranks of tokens in sequences of a
particular batch.
2. Loss computed over all logits and labels.
3. Accuracy computed over all logits and labels.
"""
ranks = []
loss = SparseCategoricalCrossentropy(from_logits=True)
accuracy = SparseCategoricalAccuracy()
for batch_logits, batch_labels in zip(logits, labels):
batch_ranks = []
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
batch_ranks += _get_ranks_for_sequence(sequence_logits, sequence_labels)
# Compute average rank for the current batch.
batch_ranks = _get_batch_ranks(batch_logits, batch_labels)
ranks.append(np.mean(batch_ranks))
return np.array(ranks)
# Update overall loss with loss of the current batch.
_update_batch_loss(batch_logits, batch_labels, loss)
# Update overall accuracy with accuracy of the current batch.
_update_batch_accuracy(batch_logits, batch_labels, accuracy)
return np.array(ranks), loss.result().numpy(), accuracy.result().numpy()
def _get_batch_ranks(batch_logits: np.ndarray,
batch_labels: np.ndarray) -> np.ndarray:
"""Returns the ranks of tokens in a batch of sequences.
Args:
batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
num_tokens, vocab_size).
batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
num_tokens, 1).
Returns:
An array of ranks of tokens in a batch of sequences, dim = (num_sequences,
num_tokens, 1)
"""
batch_ranks = []
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
batch_ranks += _get_ranks_for_sequence(sequence_logits, sequence_labels)
return np.array(batch_ranks)
def _get_ranks_for_sequence(logits: np.ndarray,
@@ -160,28 +193,67 @@ def _get_ranks_for_sequence(logits: np.ndarray,
return sequence_ranks
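Aside (illustration only, not part of the diff): the body of _get_ranks_for_sequence is unchanged by this commit and elided above. A plausible sketch of how per-token ranks can be derived with scipy's rankdata, under the assumption that a token's rank is the position of its true label's logit when logits are sorted descending:

import numpy as np
from scipy.stats import rankdata

def _example_ranks_for_sequence(sequence_logits, sequence_labels):
  # Illustrative helper: rank 1 means the model scored the correct token highest.
  ranks = []
  for token_logits, token_label in zip(sequence_logits, sequence_labels):
    descending_ranks = rankdata(-token_logits, method='min')
    ranks.append(float(descending_ranks[int(np.squeeze(token_label))]))
  return ranks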
def _update_batch_loss(batch_logits: np.ndarray,
batch_labels: np.ndarray,
loss: SparseCategoricalCrossentropy):
"""Updates the loss metric per batch.
Args:
batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
num_tokens, vocab_size).
batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
num_tokens, 1).
loss: SparseCategoricalCrossentropy loss metric.
"""
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
loss.update_state(sequence_labels.astype(np.float32),
sequence_logits.astype(np.float32))
def _update_batch_accuracy(batch_logits: np.ndarray,
batch_labels: np.ndarray,
accuracy: SparseCategoricalAccuracy):
"""Updates the accuracy metric per batch.
Args:
batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
num_tokens, vocab_size).
batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
num_tokens, 1).
accuracy: SparseCategoricalAccuracy accuracy metric.
"""
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
accuracy.update_state(sequence_labels.astype(np.float32),
sequence_logits.astype(np.float32))
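Aside (self-contained illustration, not part of the diff): the two helpers above rely on Keras streaming metrics, which accumulate state across update_state() calls, so result() reflects every token seen so far:

import numpy as np
from tensorflow.keras.metrics import SparseCategoricalAccuracy, SparseCategoricalCrossentropy

loss = SparseCategoricalCrossentropy(from_logits=True)
accuracy = SparseCategoricalAccuracy()

# A toy sequence of two tokens over a vocabulary of size 3.
sequence_logits = np.array([[2.0, 0.1, 0.1], [0.1, 2.0, 0.1]], dtype=np.float32)
sequence_labels = np.array([[0.0], [1.0]], dtype=np.float32)

loss.update_state(sequence_labels, sequence_logits)
accuracy.update_state(sequence_labels, sequence_logits)
# Both results are running averages over all update_state() calls so far.
print(loss.result().numpy(), accuracy.result().numpy())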
def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
test_fraction: float = 0.25,
balance: bool = True) -> AttackerData:
balance: bool = True,
privacy_report_metadata: PrivacyReportMetadata = None) -> AttackerData:
"""Prepare Seq2SeqAttackInputData to train ML attackers.
Uses logits and losses to generate ranks and performs a random train-test
split.
Also computes metadata (loss, accuracy) for the model under attack
and populates the corresponding fields of PrivacyReportMetadata.
Args:
attack_input_data: Original Seq2SeqAttackInputData.
test_fraction: Fraction of the dataset to include in the test split.
balance: Whether the training and test sets for the membership inference
attacker should have a balanced (roughly equal) number of samples from the
training and test sets used to develop the model under attack.
privacy_report_metadata: the metadata of the model under attack; its loss and accuracy fields are populated by this function.
Returns:
AttackerData.
"""
attack_input_train = _get_average_ranks(attack_input_data.logits_train,
attack_input_data.labels_train)
attack_input_test = _get_average_ranks(attack_input_data.logits_test,
attack_input_data.labels_test)
attack_input_train, loss_train, accuracy_train = _get_attack_features_and_metadata(attack_input_data.logits_train,
attack_input_data.labels_train)
attack_input_test, loss_test, accuracy_test = _get_attack_features_and_metadata(attack_input_data.logits_test,
attack_input_data.labels_test)
if balance:
min_size = min(len(attack_input_train), len(attack_input_test))
@@ -204,18 +276,24 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
model_selection.train_test_split(
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
# Populate fields of privacy report metadata
privacy_report_metadata.loss_train = loss_train
privacy_report_metadata.loss_test = loss_test
privacy_report_metadata.accuracy_train = accuracy_train
privacy_report_metadata.accuracy_test = accuracy_test
return AttackerData(features_train, is_training_labels_train, features_test,
is_training_labels_test)
def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
unused_report_metadata: PrivacyReportMetadata = None,
privacy_report_metadata: PrivacyReportMetadata = None,
balance_attacker_training: bool = True) -> AttackResults:
"""Runs membership inference attacks on a seq2seq model.
Args:
attack_input: input data for running an attack
unused_report_metadata: the metadata of the model under attack.
privacy_report_metadata: the metadata of the model under attack; if None, a new PrivacyReportMetadata is created and populated.
balance_attacker_training: Whether the training and test sets for the
membership inference attacker should have a balanced (roughly equal)
number of samples from the training and test sets used to develop the
@@ -231,8 +309,12 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
# as it makes the most sense for single-number features.
attacker = models.LogisticRegressionAttacker()
prepared_attacker_data = create_seq2seq_attacker_data(
attack_input, balance=balance_attacker_training)
# Create attacker data and populate fields of privacy_report_metadata
if privacy_report_metadata is None:
privacy_report_metadata = PrivacyReportMetadata()
prepared_attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input,
balance=balance_attacker_training,
privacy_report_metadata=privacy_report_metadata)
attacker.train_model(prepared_attacker_data.features_train,
prepared_attacker_data.is_training_labels_train)
@@ -253,5 +335,6 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
roc_curve=roc_curve)
]
return AttackResults(single_attack_results=attack_results)
return AttackResults(
single_attack_results=attack_results,
privacy_report_metadata=privacy_report_metadata)
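Aside (usage sketch, not part of the diff; `attack_input` stands for an already constructed Seq2SeqAttackInputData, as in the tests below):

from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.membership_inference_attack.seq2seq_mia import run_seq2seq_attack

# Option 1: let run_seq2seq_attack create the metadata object itself.
results = run_seq2seq_attack(attack_input)
print(results.privacy_report_metadata.loss_train)

# Option 2: pass a pre-built PrivacyReportMetadata; its loss/accuracy fields are filled in place.
metadata = PrivacyReportMetadata()
results = run_seq2seq_attack(attack_input, privacy_report_metadata=metadata)
assert results.privacy_report_metadata is metadata
print(metadata.loss_train, metadata.loss_test, metadata.accuracy_train, metadata.accuracy_test)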

View file

@@ -17,7 +17,7 @@
from absl.testing import absltest
import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType, PrivacyReportMetadata
from tensorflow_privacy.privacy.membership_inference_attack.seq2seq_mia import Seq2SeqAttackInputData, \
create_seq2seq_attacker_data, run_seq2seq_attack
@@ -145,13 +145,23 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
vocab_size=3,
train_size=3,
test_size=2)
attacker_data = create_seq2seq_attacker_data(attack_input, 0.25, balance=False)
privacy_report_metadata = PrivacyReportMetadata()
attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input,
test_fraction=0.25,
balance=False,
privacy_report_metadata=privacy_report_metadata)
self.assertLen(attacker_data.features_train, 3)
self.assertLen(attacker_data.features_test, 2)
for _, feature in enumerate(attacker_data.features_train):
self.assertLen(feature, 1) # each feature has one average rank
# Tests that fields of PrivacyReportMetadata are populated.
self.assertIsNotNone(privacy_report_metadata.loss_train)
self.assertIsNotNone(privacy_report_metadata.loss_test)
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
attack_input = Seq2SeqAttackInputData(
@@ -210,13 +220,23 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
vocab_size=3,
train_size=3,
test_size=3)
attacker_data = create_seq2seq_attacker_data(attack_input, 0.33, balance=True)
privacy_report_metadata = PrivacyReportMetadata()
attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input,
test_fraction=0.33,
balance=True,
privacy_report_metadata=privacy_report_metadata)
self.assertLen(attacker_data.features_train, 4)
self.assertLen(attacker_data.features_test, 2)
for _, feature in enumerate(attacker_data.features_train):
self.assertLen(feature, 1) # each feature has one average rank
# Tests that fields of PrivacyReportMetadata are populated.
self.assertIsNotNone(privacy_report_metadata.loss_train)
self.assertIsNotNone(privacy_report_metadata.loss_test)
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
vocab_size):
@@ -315,6 +335,21 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
np.testing.assert_almost_equal(
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
def test_run_seq2seq_attack_calculates_correct_metadata(self):
result = run_seq2seq_attack(get_seq2seq_test_input(
n_train=20,
n_test=10,
max_seq_in_batch=3,
max_tokens_in_sequence=5,
vocab_size=3,
seed=12345),
balance_attacker_training=False)
metadata = result.privacy_report_metadata
np.testing.assert_almost_equal(metadata.loss_train, 1.11, decimal=2)
np.testing.assert_almost_equal(metadata.loss_test, 1.10, decimal=2)
np.testing.assert_almost_equal(metadata.accuracy_train, 0.40, decimal=2)
np.testing.assert_almost_equal(metadata.accuracy_test, 0.34, decimal=2)
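Aside (an observation, not an assertion made by the commit): assuming the synthetic test logits are roughly uninformative over a vocabulary of size 3, the expected per-token cross-entropy is about ln 3 ≈ 1.10 and chance accuracy is about 1/3, which lines up with the values asserted above:

import numpy as np
print(np.log(3))  # ≈ 1.0986, close to the asserted 1.10/1.11 losses
print(1 / 3)      # ≈ 0.33, in the range of the asserted 0.34/0.40 accuracies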
if __name__ == '__main__':
absltest.main()