Compute and populate PrivacyReportMetadata fields
This commit is contained in:
parent
46bee91cda
commit
eb215072bc
3 changed files with 146 additions and 25 deletions
|
@ -1163,7 +1163,10 @@
|
|||
"fig = plotting.plot_roc_curve(attack_result.get_result_with_max_auc().roc_curve)\n",
|
||||
"\n",
|
||||
"# Print a user-friendly summary of the attacks\n",
|
||||
"print(attack_result.summary())"
|
||||
"print(attack_result.summary())\n",
|
||||
"\n",
|
||||
"# Print metadata of the target model\n",
|
||||
"print(attack_result.privacy_report_metadata)"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
|
@ -25,9 +25,11 @@ from dataclasses import dataclass
|
|||
from scipy.stats import rankdata
|
||||
from sklearn import metrics, model_selection
|
||||
|
||||
from tensorflow.keras.metrics import SparseCategoricalCrossentropy, SparseCategoricalAccuracy
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack import models
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata, AttackResults, \
|
||||
RocCurve, SingleAttackResult, SingleSliceSpec, AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata, \
|
||||
AttackResults, RocCurve, SingleAttackResult, SingleSliceSpec, AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.models import AttackerData, _sample_multidimensional_array
|
||||
|
||||
|
||||
|
@ -116,9 +118,10 @@ class Seq2SeqAttackInputData:
|
|||
return '\n'.join(result)
|
||||
|
||||
|
||||
def _get_average_ranks(logits: Iterator[np.ndarray],
|
||||
labels: Iterator[np.ndarray]) -> np.ndarray:
|
||||
"""Returns the average rank of tokens in a batch of sequences.
|
||||
def _get_attack_features_and_metadata(logits: Iterator[np.ndarray],
|
||||
labels: Iterator[np.ndarray]) -> (np.ndarray, float):
|
||||
"""Returns the average rank of tokens per batch of sequences,
|
||||
and the loss computed using logits and labels.
|
||||
|
||||
Args:
|
||||
logits: Logits returned by a seq2seq model, dim = (num_batches,
|
||||
|
@ -127,18 +130,48 @@ def _get_average_ranks(logits: Iterator[np.ndarray],
|
|||
num_sequences, num_tokens, 1).
|
||||
|
||||
Returns:
|
||||
An array of average ranks, dim = (num_batches, 1).
|
||||
1. An array of average ranks, dim = (num_batches, 1).
|
||||
Each average rank is calculated over ranks of tokens in sequences of a
|
||||
particular batch.
|
||||
2. Loss computed over all logits and labels.
|
||||
3. Accuracy computed over all logits and labels.
|
||||
"""
|
||||
ranks = []
|
||||
loss = SparseCategoricalCrossentropy(from_logits=True)
|
||||
accuracy = SparseCategoricalAccuracy()
|
||||
for batch_logits, batch_labels in zip(logits, labels):
|
||||
batch_ranks = []
|
||||
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
|
||||
batch_ranks += _get_ranks_for_sequence(sequence_logits, sequence_labels)
|
||||
# Compute average rank for the current batch.
|
||||
batch_ranks = _get_batch_ranks(batch_logits, batch_labels)
|
||||
ranks.append(np.mean(batch_ranks))
|
||||
|
||||
return np.array(ranks)
|
||||
# Update overall loss with loss of the current batch.
|
||||
_update_batch_loss(batch_logits, batch_labels, loss)
|
||||
|
||||
# Update overall accuracy with accuracy of the current batch.
|
||||
_update_batch_accuracy(batch_logits, batch_labels, accuracy)
|
||||
|
||||
return np.array(ranks), loss.result().numpy(), accuracy.result().numpy()
|
||||
|
||||
|
||||
def _get_batch_ranks(batch_logits: np.ndarray,
|
||||
batch_labels: np.ndarray) -> np.ndarray:
|
||||
"""Returns the ranks of tokens in a batch of sequences.
|
||||
|
||||
Args:
|
||||
batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
|
||||
num_tokens, vocab_size).
|
||||
batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
|
||||
num_tokens, 1).
|
||||
|
||||
Returns:
|
||||
An array of ranks of tokens in a batch of sequences, dim = (num_sequences,
|
||||
num_tokens, 1)
|
||||
"""
|
||||
batch_ranks = []
|
||||
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
|
||||
batch_ranks += _get_ranks_for_sequence(sequence_logits, sequence_labels)
|
||||
|
||||
return np.array(batch_ranks)
|
||||
|
||||
|
||||
def _get_ranks_for_sequence(logits: np.ndarray,
|
||||
|
@ -160,28 +193,67 @@ def _get_ranks_for_sequence(logits: np.ndarray,
|
|||
return sequence_ranks
|
||||
|
||||
|
||||
def _update_batch_loss(batch_logits: np.ndarray,
|
||||
batch_labels: np.ndarray,
|
||||
loss: SparseCategoricalCrossentropy):
|
||||
"""Updates the loss metric per batch.
|
||||
|
||||
Args:
|
||||
batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
|
||||
num_tokens, vocab_size).
|
||||
batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
|
||||
num_tokens, 1).
|
||||
loss: SparseCategoricalCrossentropy loss metric.
|
||||
"""
|
||||
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
|
||||
loss.update_state(sequence_labels.astype(np.float32),
|
||||
sequence_logits.astype(np.float32))
|
||||
|
||||
|
||||
def _update_batch_accuracy(batch_logits: np.ndarray,
|
||||
batch_labels: np.ndarray,
|
||||
accuracy: SparseCategoricalAccuracy):
|
||||
"""Updates the accuracy metric per batch.
|
||||
|
||||
Args:
|
||||
batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
|
||||
num_tokens, vocab_size).
|
||||
batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
|
||||
num_tokens, 1).
|
||||
accuracy: SparseCategoricalAccuracy accuracy metric.
|
||||
"""
|
||||
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
|
||||
accuracy.update_state(sequence_labels.astype(np.float32),
|
||||
sequence_logits.astype(np.float32))
|
||||
|
||||
|
||||
def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
|
||||
test_fraction: float = 0.25,
|
||||
balance: bool = True) -> AttackerData:
|
||||
balance: bool = True,
|
||||
privacy_report_metadata: PrivacyReportMetadata = None) -> AttackerData:
|
||||
"""Prepare Seq2SeqAttackInputData to train ML attackers.
|
||||
|
||||
Uses logits and losses to generate ranks and performs a random train-test
|
||||
split.
|
||||
|
||||
Also computes metadata (loss, accuracy) for the model under attack
|
||||
and populates respective fields of PrivacyReportMetadata.
|
||||
|
||||
Args:
|
||||
attack_input_data: Original Seq2SeqAttackInputData
|
||||
test_fraction: Fraction of the dataset to include in the test split.
|
||||
balance: Whether the training and test sets for the membership inference
|
||||
attacker should have a balanced (roughly equal) number of samples from the
|
||||
training and test sets used to develop the model under attack.
|
||||
privacy_report_metadata: the metadata of the model under attack.
|
||||
|
||||
Returns:
|
||||
AttackerData.
|
||||
"""
|
||||
attack_input_train = _get_average_ranks(attack_input_data.logits_train,
|
||||
attack_input_data.labels_train)
|
||||
attack_input_test = _get_average_ranks(attack_input_data.logits_test,
|
||||
attack_input_data.labels_test)
|
||||
attack_input_train, loss_train, accuracy_train = _get_attack_features_and_metadata(attack_input_data.logits_train,
|
||||
attack_input_data.labels_train)
|
||||
attack_input_test, loss_test, accuracy_test = _get_attack_features_and_metadata(attack_input_data.logits_test,
|
||||
attack_input_data.labels_test)
|
||||
|
||||
if balance:
|
||||
min_size = min(len(attack_input_train), len(attack_input_test))
|
||||
|
@ -204,18 +276,24 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
|
|||
model_selection.train_test_split(
|
||||
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
|
||||
|
||||
# Populate fields of privacy report metadata
|
||||
privacy_report_metadata.loss_train = loss_train
|
||||
privacy_report_metadata.loss_test = loss_test
|
||||
privacy_report_metadata.accuracy_train = accuracy_train
|
||||
privacy_report_metadata.accuracy_test = accuracy_test
|
||||
|
||||
return AttackerData(features_train, is_training_labels_train, features_test,
|
||||
is_training_labels_test)
|
||||
|
||||
|
||||
def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
||||
unused_report_metadata: PrivacyReportMetadata = None,
|
||||
privacy_report_metadata: PrivacyReportMetadata = None,
|
||||
balance_attacker_training: bool = True) -> AttackResults:
|
||||
"""Runs membership inference attacks on a seq2seq model.
|
||||
|
||||
Args:
|
||||
attack_input: input data for running an attack
|
||||
unused_report_metadata: the metadata of the model under attack.
|
||||
privacy_report_metadata: the metadata of the model under attack.
|
||||
balance_attacker_training: Whether the training and test sets for the
|
||||
membership inference attacker should have a balanced (roughly equal)
|
||||
number of samples from the training and test sets used to develop the
|
||||
|
@ -231,8 +309,12 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
|||
# as it makes the most sense for single-number features.
|
||||
attacker = models.LogisticRegressionAttacker()
|
||||
|
||||
prepared_attacker_data = create_seq2seq_attacker_data(
|
||||
attack_input, balance=balance_attacker_training)
|
||||
# Create attacker data and populate fields of privacy_report_metadata
|
||||
if privacy_report_metadata is None:
|
||||
privacy_report_metadata = PrivacyReportMetadata()
|
||||
prepared_attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input,
|
||||
balance=balance_attacker_training,
|
||||
privacy_report_metadata=privacy_report_metadata)
|
||||
|
||||
attacker.train_model(prepared_attacker_data.features_train,
|
||||
prepared_attacker_data.is_training_labels_train)
|
||||
|
@ -253,5 +335,6 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
|||
roc_curve=roc_curve)
|
||||
]
|
||||
|
||||
return AttackResults(single_attack_results=attack_results)
|
||||
|
||||
return AttackResults(
|
||||
single_attack_results=attack_results,
|
||||
privacy_report_metadata=privacy_report_metadata)
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType, PrivacyReportMetadata
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.seq2seq_mia import Seq2SeqAttackInputData, \
|
||||
create_seq2seq_attacker_data, run_seq2seq_attack
|
||||
|
||||
|
@ -145,13 +145,23 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
|
|||
vocab_size=3,
|
||||
train_size=3,
|
||||
test_size=2)
|
||||
attacker_data = create_seq2seq_attacker_data(attack_input, 0.25, balance=False)
|
||||
privacy_report_metadata = PrivacyReportMetadata()
|
||||
attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input,
|
||||
test_fraction=0.25,
|
||||
balance=False,
|
||||
privacy_report_metadata=privacy_report_metadata)
|
||||
self.assertLen(attacker_data.features_train, 3)
|
||||
self.assertLen(attacker_data.features_test, 2)
|
||||
|
||||
for _, feature in enumerate(attacker_data.features_train):
|
||||
self.assertLen(feature, 1) # each feature has one average rank
|
||||
|
||||
# Tests that fields of PrivacyReportMetadata are populated.
|
||||
self.assertIsNotNone(privacy_report_metadata.loss_train)
|
||||
self.assertIsNotNone(privacy_report_metadata.loss_test)
|
||||
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
|
||||
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
|
||||
|
||||
|
||||
def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
|
||||
attack_input = Seq2SeqAttackInputData(
|
||||
|
@ -210,13 +220,23 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
|
|||
vocab_size=3,
|
||||
train_size=3,
|
||||
test_size=3)
|
||||
attacker_data = create_seq2seq_attacker_data(attack_input, 0.33, balance=True)
|
||||
privacy_report_metadata = PrivacyReportMetadata()
|
||||
attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input,
|
||||
test_fraction=0.33,
|
||||
balance=True,
|
||||
privacy_report_metadata=privacy_report_metadata)
|
||||
self.assertLen(attacker_data.features_train, 4)
|
||||
self.assertLen(attacker_data.features_test, 2)
|
||||
|
||||
for _, feature in enumerate(attacker_data.features_train):
|
||||
self.assertLen(feature, 1) # each feature has one average rank
|
||||
|
||||
# Tests that fields of PrivacyReportMetadata are populated.
|
||||
self.assertIsNotNone(privacy_report_metadata.loss_train)
|
||||
self.assertIsNotNone(privacy_report_metadata.loss_test)
|
||||
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
|
||||
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
|
||||
|
||||
|
||||
def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
|
||||
vocab_size):
|
||||
|
@ -315,6 +335,21 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
|
|||
np.testing.assert_almost_equal(
|
||||
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
|
||||
|
||||
def test_run_seq2seq_attack_calculates_correct_metadata(self):
|
||||
result = run_seq2seq_attack(get_seq2seq_test_input(
|
||||
n_train=20,
|
||||
n_test=10,
|
||||
max_seq_in_batch=3,
|
||||
max_tokens_in_sequence=5,
|
||||
vocab_size=3,
|
||||
seed=12345),
|
||||
balance_attacker_training=False)
|
||||
metadata = result.privacy_report_metadata
|
||||
np.testing.assert_almost_equal(metadata.loss_train, 1.11, decimal=2)
|
||||
np.testing.assert_almost_equal(metadata.loss_test, 1.10, decimal=2)
|
||||
np.testing.assert_almost_equal(metadata.accuracy_train, 0.40, decimal=2)
|
||||
np.testing.assert_almost_equal(metadata.accuracy_test, 0.34, decimal=2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue