Compute and populate PrivacyReportMetadata fields

This commit is contained in:
amad-person 2020-11-25 16:06:37 +08:00
parent 46bee91cda
commit eb215072bc
3 changed files with 146 additions and 25 deletions

View file

@ -1163,7 +1163,10 @@
"fig = plotting.plot_roc_curve(attack_result.get_result_with_max_auc().roc_curve)\n", "fig = plotting.plot_roc_curve(attack_result.get_result_with_max_auc().roc_curve)\n",
"\n", "\n",
"# Print a user-friendly summary of the attacks\n", "# Print a user-friendly summary of the attacks\n",
"print(attack_result.summary())" "print(attack_result.summary())\n",
"\n",
"# Print metadata of the target model\n",
"print(attack_result.privacy_report_metadata)"
] ]
} }
], ],

View file

@ -25,9 +25,11 @@ from dataclasses import dataclass
from scipy.stats import rankdata from scipy.stats import rankdata
from sklearn import metrics, model_selection from sklearn import metrics, model_selection
from tensorflow.keras.metrics import SparseCategoricalCrossentropy, SparseCategoricalAccuracy
from tensorflow_privacy.privacy.membership_inference_attack import models from tensorflow_privacy.privacy.membership_inference_attack import models
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata, AttackResults, \ from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata, \
RocCurve, SingleAttackResult, SingleSliceSpec, AttackType AttackResults, RocCurve, SingleAttackResult, SingleSliceSpec, AttackType
from tensorflow_privacy.privacy.membership_inference_attack.models import AttackerData, _sample_multidimensional_array from tensorflow_privacy.privacy.membership_inference_attack.models import AttackerData, _sample_multidimensional_array
@ -116,9 +118,10 @@ class Seq2SeqAttackInputData:
return '\n'.join(result) return '\n'.join(result)
def _get_average_ranks(logits: Iterator[np.ndarray], def _get_attack_features_and_metadata(logits: Iterator[np.ndarray],
labels: Iterator[np.ndarray]) -> np.ndarray: labels: Iterator[np.ndarray]) -> (np.ndarray, float):
"""Returns the average rank of tokens in a batch of sequences. """Returns the average rank of tokens per batch of sequences,
and the loss and accuracy computed using logits and labels.
Args: Args:
logits: Logits returned by a seq2seq model, dim = (num_batches, logits: Logits returned by a seq2seq model, dim = (num_batches,
@ -127,18 +130,48 @@ def _get_average_ranks(logits: Iterator[np.ndarray],
num_sequences, num_tokens, 1). num_sequences, num_tokens, 1).
Returns: Returns:
An array of average ranks, dim = (num_batches, 1). 1. An array of average ranks, dim = (num_batches, 1).
Each average rank is calculated over ranks of tokens in sequences of a Each average rank is calculated over ranks of tokens in sequences of a
particular batch. particular batch.
2. Loss computed over all logits and labels.
3. Accuracy computed over all logits and labels.
""" """
ranks = [] ranks = []
loss = SparseCategoricalCrossentropy(from_logits=True)
accuracy = SparseCategoricalAccuracy()
for batch_logits, batch_labels in zip(logits, labels): for batch_logits, batch_labels in zip(logits, labels):
batch_ranks = [] # Compute average rank for the current batch.
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels): batch_ranks = _get_batch_ranks(batch_logits, batch_labels)
batch_ranks += _get_ranks_for_sequence(sequence_logits, sequence_labels)
ranks.append(np.mean(batch_ranks)) ranks.append(np.mean(batch_ranks))
return np.array(ranks) # Update overall loss with loss of the current batch.
_update_batch_loss(batch_logits, batch_labels, loss)
# Update overall accuracy with accuracy of the current batch.
_update_batch_accuracy(batch_logits, batch_labels, accuracy)
return np.array(ranks), loss.result().numpy(), accuracy.result().numpy()
def _get_batch_ranks(batch_logits: np.ndarray,
                     batch_labels: np.ndarray) -> np.ndarray:
  """Returns the ranks of all tokens in one batch of sequences.

  Every sequence's logits and labels are passed to `_get_ranks_for_sequence`
  and the per-sequence results are concatenated in order, so the output has
  no separate per-sequence axis.

  Args:
    batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
      num_tokens, vocab_size).
    batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
      num_tokens, 1).

  Returns:
    An array with the rank entries of every token in the batch, sequences
    concatenated along the first axis (one entry per token, in sequence
    order). The shape of a single entry is whatever `_get_ranks_for_sequence`
    produces per token.
  """
  batch_ranks = []
  for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
    # extend (rather than append) keeps the result flat across sequences.
    batch_ranks.extend(
        _get_ranks_for_sequence(sequence_logits, sequence_labels))
  return np.array(batch_ranks)
def _get_ranks_for_sequence(logits: np.ndarray, def _get_ranks_for_sequence(logits: np.ndarray,
@ -160,28 +193,67 @@ def _get_ranks_for_sequence(logits: np.ndarray,
return sequence_ranks return sequence_ranks
def _update_batch_loss(batch_logits: np.ndarray,
                       batch_labels: np.ndarray,
                       loss: SparseCategoricalCrossentropy):
  """Feeds every sequence of one batch into the running loss metric.

  The metric object is updated in place; nothing is returned.

  Args:
    batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
      num_tokens, vocab_size).
    batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
      num_tokens, 1).
    loss: SparseCategoricalCrossentropy loss metric that accumulates state
      across calls.
  """
  for seq_logits, seq_labels in zip(batch_logits, batch_labels):
    # Both operands are cast to float32 before being handed to the metric.
    y_true = seq_labels.astype(np.float32)
    y_pred = seq_logits.astype(np.float32)
    loss.update_state(y_true, y_pred)
def _update_batch_accuracy(batch_logits: np.ndarray,
                           batch_labels: np.ndarray,
                           accuracy: SparseCategoricalAccuracy):
  """Feeds every sequence of one batch into the running accuracy metric.

  The metric object is updated in place; nothing is returned.

  Args:
    batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
      num_tokens, vocab_size).
    batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
      num_tokens, 1).
    accuracy: SparseCategoricalAccuracy accuracy metric that accumulates
      state across calls.
  """
  for seq_logits, seq_labels in zip(batch_logits, batch_labels):
    # Both operands are cast to float32 before being handed to the metric.
    y_true = seq_labels.astype(np.float32)
    y_pred = seq_logits.astype(np.float32)
    accuracy.update_state(y_true, y_pred)
def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData, def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
test_fraction: float = 0.25, test_fraction: float = 0.25,
balance: bool = True) -> AttackerData: balance: bool = True,
privacy_report_metadata: PrivacyReportMetadata = None) -> AttackerData:
"""Prepare Seq2SeqAttackInputData to train ML attackers. """Prepare Seq2SeqAttackInputData to train ML attackers.
Uses logits and losses to generate ranks and performs a random train-test Uses logits and losses to generate ranks and performs a random train-test
split. split.
Also computes metadata (loss, accuracy) for the model under attack
and populates respective fields of PrivacyReportMetadata.
Args: Args:
attack_input_data: Original Seq2SeqAttackInputData attack_input_data: Original Seq2SeqAttackInputData
test_fraction: Fraction of the dataset to include in the test split. test_fraction: Fraction of the dataset to include in the test split.
balance: Whether the training and test sets for the membership inference balance: Whether the training and test sets for the membership inference
attacker should have a balanced (roughly equal) number of samples from the attacker should have a balanced (roughly equal) number of samples from the
training and test sets used to develop the model under attack. training and test sets used to develop the model under attack.
privacy_report_metadata: the metadata of the model under attack. Must not
be None: its loss and accuracy fields are populated in place.
Returns: Returns:
AttackerData. AttackerData.
""" """
attack_input_train = _get_average_ranks(attack_input_data.logits_train, attack_input_train, loss_train, accuracy_train = _get_attack_features_and_metadata(attack_input_data.logits_train,
attack_input_data.labels_train) attack_input_data.labels_train)
attack_input_test = _get_average_ranks(attack_input_data.logits_test, attack_input_test, loss_test, accuracy_test = _get_attack_features_and_metadata(attack_input_data.logits_test,
attack_input_data.labels_test) attack_input_data.labels_test)
if balance: if balance:
min_size = min(len(attack_input_train), len(attack_input_test)) min_size = min(len(attack_input_train), len(attack_input_test))
@ -204,18 +276,24 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
model_selection.train_test_split( model_selection.train_test_split(
features_all, labels_all, test_size=test_fraction, stratify=labels_all) features_all, labels_all, test_size=test_fraction, stratify=labels_all)
# Populate fields of privacy report metadata
privacy_report_metadata.loss_train = loss_train
privacy_report_metadata.loss_test = loss_test
privacy_report_metadata.accuracy_train = accuracy_train
privacy_report_metadata.accuracy_test = accuracy_test
return AttackerData(features_train, is_training_labels_train, features_test, return AttackerData(features_train, is_training_labels_train, features_test,
is_training_labels_test) is_training_labels_test)
def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData, def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
unused_report_metadata: PrivacyReportMetadata = None, privacy_report_metadata: PrivacyReportMetadata = None,
balance_attacker_training: bool = True) -> AttackResults: balance_attacker_training: bool = True) -> AttackResults:
"""Runs membership inference attacks on a seq2seq model. """Runs membership inference attacks on a seq2seq model.
Args: Args:
attack_input: input data for running an attack attack_input: input data for running an attack
unused_report_metadata: the metadata of the model under attack. privacy_report_metadata: the metadata of the model under attack.
balance_attacker_training: Whether the training and test sets for the balance_attacker_training: Whether the training and test sets for the
membership inference attacker should have a balanced (roughly equal) membership inference attacker should have a balanced (roughly equal)
number of samples from the training and test sets used to develop the number of samples from the training and test sets used to develop the
@ -231,8 +309,12 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
# as it makes the most sense for single-number features. # as it makes the most sense for single-number features.
attacker = models.LogisticRegressionAttacker() attacker = models.LogisticRegressionAttacker()
prepared_attacker_data = create_seq2seq_attacker_data( # Create attacker data and populate fields of privacy_report_metadata
attack_input, balance=balance_attacker_training) if privacy_report_metadata is None:
privacy_report_metadata = PrivacyReportMetadata()
prepared_attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input,
balance=balance_attacker_training,
privacy_report_metadata=privacy_report_metadata)
attacker.train_model(prepared_attacker_data.features_train, attacker.train_model(prepared_attacker_data.features_train,
prepared_attacker_data.is_training_labels_train) prepared_attacker_data.is_training_labels_train)
@ -253,5 +335,6 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
roc_curve=roc_curve) roc_curve=roc_curve)
] ]
return AttackResults(single_attack_results=attack_results) return AttackResults(
single_attack_results=attack_results,
privacy_report_metadata=privacy_report_metadata)

View file

@ -17,7 +17,7 @@
from absl.testing import absltest from absl.testing import absltest
import numpy as np import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType, PrivacyReportMetadata
from tensorflow_privacy.privacy.membership_inference_attack.seq2seq_mia import Seq2SeqAttackInputData, \ from tensorflow_privacy.privacy.membership_inference_attack.seq2seq_mia import Seq2SeqAttackInputData, \
create_seq2seq_attacker_data, run_seq2seq_attack create_seq2seq_attacker_data, run_seq2seq_attack
@ -145,13 +145,23 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
vocab_size=3, vocab_size=3,
train_size=3, train_size=3,
test_size=2) test_size=2)
attacker_data = create_seq2seq_attacker_data(attack_input, 0.25, balance=False) privacy_report_metadata = PrivacyReportMetadata()
attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input,
test_fraction=0.25,
balance=False,
privacy_report_metadata=privacy_report_metadata)
self.assertLen(attacker_data.features_train, 3) self.assertLen(attacker_data.features_train, 3)
self.assertLen(attacker_data.features_test, 2) self.assertLen(attacker_data.features_test, 2)
for _, feature in enumerate(attacker_data.features_train): for _, feature in enumerate(attacker_data.features_train):
self.assertLen(feature, 1) # each feature has one average rank self.assertLen(feature, 1) # each feature has one average rank
# Tests that fields of PrivacyReportMetadata are populated.
self.assertIsNotNone(privacy_report_metadata.loss_train)
self.assertIsNotNone(privacy_report_metadata.loss_test)
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self): def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
attack_input = Seq2SeqAttackInputData( attack_input = Seq2SeqAttackInputData(
@ -210,13 +220,23 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
vocab_size=3, vocab_size=3,
train_size=3, train_size=3,
test_size=3) test_size=3)
attacker_data = create_seq2seq_attacker_data(attack_input, 0.33, balance=True) privacy_report_metadata = PrivacyReportMetadata()
attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input,
test_fraction=0.33,
balance=True,
privacy_report_metadata=privacy_report_metadata)
self.assertLen(attacker_data.features_train, 4) self.assertLen(attacker_data.features_train, 4)
self.assertLen(attacker_data.features_test, 2) self.assertLen(attacker_data.features_test, 2)
for _, feature in enumerate(attacker_data.features_train): for _, feature in enumerate(attacker_data.features_train):
self.assertLen(feature, 1) # each feature has one average rank self.assertLen(feature, 1) # each feature has one average rank
# Tests that fields of PrivacyReportMetadata are populated.
self.assertIsNotNone(privacy_report_metadata.loss_train)
self.assertIsNotNone(privacy_report_metadata.loss_test)
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence, def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
vocab_size): vocab_size):
@ -315,6 +335,21 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
np.testing.assert_almost_equal( np.testing.assert_almost_equal(
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2) seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
def test_run_seq2seq_attack_calculates_correct_metadata(self):
  """Checks the loss/accuracy metadata attached to the attack results."""
  attack_input = get_seq2seq_test_input(
      n_train=20,
      n_test=10,
      max_seq_in_batch=3,
      max_tokens_in_sequence=5,
      vocab_size=3,
      seed=12345)
  attack_results = run_seq2seq_attack(
      attack_input, balance_attacker_training=False)
  report = attack_results.privacy_report_metadata

  # Expected values are compared to two decimals for the seeded input above.
  np.testing.assert_almost_equal(report.loss_train, 1.11, decimal=2)
  np.testing.assert_almost_equal(report.loss_test, 1.10, decimal=2)
  np.testing.assert_almost_equal(report.accuracy_train, 0.40, decimal=2)
  np.testing.assert_almost_equal(report.accuracy_test, 0.34, decimal=2)
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()