Compute and populate PrivacyReportMetadata fields
This commit is contained in:
parent
46bee91cda
commit
eb215072bc
3 changed files with 146 additions and 25 deletions
|
@ -1163,7 +1163,10 @@
|
||||||
"fig = plotting.plot_roc_curve(attack_result.get_result_with_max_auc().roc_curve)\n",
|
"fig = plotting.plot_roc_curve(attack_result.get_result_with_max_auc().roc_curve)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Print a user-friendly summary of the attacks\n",
|
"# Print a user-friendly summary of the attacks\n",
|
||||||
"print(attack_result.summary())"
|
"print(attack_result.summary())\n",
|
||||||
|
"\n",
|
||||||
|
"# Print metadata of the target model\n",
|
||||||
|
"print(attack_result.privacy_report_metadata)"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|
|
@ -25,9 +25,11 @@ from dataclasses import dataclass
|
||||||
from scipy.stats import rankdata
|
from scipy.stats import rankdata
|
||||||
from sklearn import metrics, model_selection
|
from sklearn import metrics, model_selection
|
||||||
|
|
||||||
|
from tensorflow.keras.metrics import SparseCategoricalCrossentropy, SparseCategoricalAccuracy
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack import models
|
from tensorflow_privacy.privacy.membership_inference_attack import models
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata, AttackResults, \
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata, \
|
||||||
RocCurve, SingleAttackResult, SingleSliceSpec, AttackType
|
AttackResults, RocCurve, SingleAttackResult, SingleSliceSpec, AttackType
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.models import AttackerData, _sample_multidimensional_array
|
from tensorflow_privacy.privacy.membership_inference_attack.models import AttackerData, _sample_multidimensional_array
|
||||||
|
|
||||||
|
|
||||||
|
@ -116,9 +118,10 @@ class Seq2SeqAttackInputData:
|
||||||
return '\n'.join(result)
|
return '\n'.join(result)
|
||||||
|
|
||||||
|
|
||||||
def _get_average_ranks(logits: Iterator[np.ndarray],
|
def _get_attack_features_and_metadata(logits: Iterator[np.ndarray],
|
||||||
labels: Iterator[np.ndarray]) -> np.ndarray:
|
labels: Iterator[np.ndarray]) -> (np.ndarray, float):
|
||||||
"""Returns the average rank of tokens in a batch of sequences.
|
"""Returns the average rank of tokens per batch of sequences,
|
||||||
|
and the loss computed using logits and labels.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logits: Logits returned by a seq2seq model, dim = (num_batches,
|
logits: Logits returned by a seq2seq model, dim = (num_batches,
|
||||||
|
@ -127,18 +130,48 @@ def _get_average_ranks(logits: Iterator[np.ndarray],
|
||||||
num_sequences, num_tokens, 1).
|
num_sequences, num_tokens, 1).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An array of average ranks, dim = (num_batches, 1).
|
1. An array of average ranks, dim = (num_batches, 1).
|
||||||
Each average rank is calculated over ranks of tokens in sequences of a
|
Each average rank is calculated over ranks of tokens in sequences of a
|
||||||
particular batch.
|
particular batch.
|
||||||
|
2. Loss computed over all logits and labels.
|
||||||
|
3. Accuracy computed over all logits and labels.
|
||||||
"""
|
"""
|
||||||
ranks = []
|
ranks = []
|
||||||
|
loss = SparseCategoricalCrossentropy(from_logits=True)
|
||||||
|
accuracy = SparseCategoricalAccuracy()
|
||||||
for batch_logits, batch_labels in zip(logits, labels):
|
for batch_logits, batch_labels in zip(logits, labels):
|
||||||
batch_ranks = []
|
# Compute average rank for the current batch.
|
||||||
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
|
batch_ranks = _get_batch_ranks(batch_logits, batch_labels)
|
||||||
batch_ranks += _get_ranks_for_sequence(sequence_logits, sequence_labels)
|
|
||||||
ranks.append(np.mean(batch_ranks))
|
ranks.append(np.mean(batch_ranks))
|
||||||
|
|
||||||
return np.array(ranks)
|
# Update overall loss with loss of the current batch.
|
||||||
|
_update_batch_loss(batch_logits, batch_labels, loss)
|
||||||
|
|
||||||
|
# Update overall accuracy with accuracy of the current batch.
|
||||||
|
_update_batch_accuracy(batch_logits, batch_labels, accuracy)
|
||||||
|
|
||||||
|
return np.array(ranks), loss.result().numpy(), accuracy.result().numpy()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_batch_ranks(batch_logits: np.ndarray,
|
||||||
|
batch_labels: np.ndarray) -> np.ndarray:
|
||||||
|
"""Returns the ranks of tokens in a batch of sequences.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
|
||||||
|
num_tokens, vocab_size).
|
||||||
|
batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
|
||||||
|
num_tokens, 1).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An array of ranks of tokens in a batch of sequences, dim = (num_sequences,
|
||||||
|
num_tokens, 1)
|
||||||
|
"""
|
||||||
|
batch_ranks = []
|
||||||
|
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
|
||||||
|
batch_ranks += _get_ranks_for_sequence(sequence_logits, sequence_labels)
|
||||||
|
|
||||||
|
return np.array(batch_ranks)
|
||||||
|
|
||||||
|
|
||||||
def _get_ranks_for_sequence(logits: np.ndarray,
|
def _get_ranks_for_sequence(logits: np.ndarray,
|
||||||
|
@ -160,28 +193,67 @@ def _get_ranks_for_sequence(logits: np.ndarray,
|
||||||
return sequence_ranks
|
return sequence_ranks
|
||||||
|
|
||||||
|
|
||||||
|
def _update_batch_loss(batch_logits: np.ndarray,
|
||||||
|
batch_labels: np.ndarray,
|
||||||
|
loss: SparseCategoricalCrossentropy):
|
||||||
|
"""Updates the loss metric per batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
|
||||||
|
num_tokens, vocab_size).
|
||||||
|
batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
|
||||||
|
num_tokens, 1).
|
||||||
|
loss: SparseCategoricalCrossentropy loss metric.
|
||||||
|
"""
|
||||||
|
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
|
||||||
|
loss.update_state(sequence_labels.astype(np.float32),
|
||||||
|
sequence_logits.astype(np.float32))
|
||||||
|
|
||||||
|
|
||||||
|
def _update_batch_accuracy(batch_logits: np.ndarray,
|
||||||
|
batch_labels: np.ndarray,
|
||||||
|
accuracy: SparseCategoricalAccuracy):
|
||||||
|
"""Updates the accuracy metric per batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
|
||||||
|
num_tokens, vocab_size).
|
||||||
|
batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
|
||||||
|
num_tokens, 1).
|
||||||
|
accuracy: SparseCategoricalAccuracy accuracy metric.
|
||||||
|
"""
|
||||||
|
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
|
||||||
|
accuracy.update_state(sequence_labels.astype(np.float32),
|
||||||
|
sequence_logits.astype(np.float32))
|
||||||
|
|
||||||
|
|
||||||
def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
|
def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
|
||||||
test_fraction: float = 0.25,
|
test_fraction: float = 0.25,
|
||||||
balance: bool = True) -> AttackerData:
|
balance: bool = True,
|
||||||
|
privacy_report_metadata: PrivacyReportMetadata = None) -> AttackerData:
|
||||||
"""Prepare Seq2SeqAttackInputData to train ML attackers.
|
"""Prepare Seq2SeqAttackInputData to train ML attackers.
|
||||||
|
|
||||||
Uses logits and losses to generate ranks and performs a random train-test
|
Uses logits and losses to generate ranks and performs a random train-test
|
||||||
split.
|
split.
|
||||||
|
|
||||||
|
Also computes metadata (loss, accuracy) for the model under attack
|
||||||
|
and populates respective fields of PrivacyReportMetadata.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
attack_input_data: Original Seq2SeqAttackInputData
|
attack_input_data: Original Seq2SeqAttackInputData
|
||||||
test_fraction: Fraction of the dataset to include in the test split.
|
test_fraction: Fraction of the dataset to include in the test split.
|
||||||
balance: Whether the training and test sets for the membership inference
|
balance: Whether the training and test sets for the membership inference
|
||||||
attacker should have a balanced (roughly equal) number of samples from the
|
attacker should have a balanced (roughly equal) number of samples from the
|
||||||
training and test sets used to develop the model under attack.
|
training and test sets used to develop the model under attack.
|
||||||
|
privacy_report_metadata: the metadata of the model under attack.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AttackerData.
|
AttackerData.
|
||||||
"""
|
"""
|
||||||
attack_input_train = _get_average_ranks(attack_input_data.logits_train,
|
attack_input_train, loss_train, accuracy_train = _get_attack_features_and_metadata(attack_input_data.logits_train,
|
||||||
attack_input_data.labels_train)
|
attack_input_data.labels_train)
|
||||||
attack_input_test = _get_average_ranks(attack_input_data.logits_test,
|
attack_input_test, loss_test, accuracy_test = _get_attack_features_and_metadata(attack_input_data.logits_test,
|
||||||
attack_input_data.labels_test)
|
attack_input_data.labels_test)
|
||||||
|
|
||||||
if balance:
|
if balance:
|
||||||
min_size = min(len(attack_input_train), len(attack_input_test))
|
min_size = min(len(attack_input_train), len(attack_input_test))
|
||||||
|
@ -204,18 +276,24 @@ def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
|
||||||
model_selection.train_test_split(
|
model_selection.train_test_split(
|
||||||
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
|
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
|
||||||
|
|
||||||
|
# Populate fields of privacy report metadata
|
||||||
|
privacy_report_metadata.loss_train = loss_train
|
||||||
|
privacy_report_metadata.loss_test = loss_test
|
||||||
|
privacy_report_metadata.accuracy_train = accuracy_train
|
||||||
|
privacy_report_metadata.accuracy_test = accuracy_test
|
||||||
|
|
||||||
return AttackerData(features_train, is_training_labels_train, features_test,
|
return AttackerData(features_train, is_training_labels_train, features_test,
|
||||||
is_training_labels_test)
|
is_training_labels_test)
|
||||||
|
|
||||||
|
|
||||||
def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
||||||
unused_report_metadata: PrivacyReportMetadata = None,
|
privacy_report_metadata: PrivacyReportMetadata = None,
|
||||||
balance_attacker_training: bool = True) -> AttackResults:
|
balance_attacker_training: bool = True) -> AttackResults:
|
||||||
"""Runs membership inference attacks on a seq2seq model.
|
"""Runs membership inference attacks on a seq2seq model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
attack_input: input data for running an attack
|
attack_input: input data for running an attack
|
||||||
unused_report_metadata: the metadata of the model under attack.
|
privacy_report_metadata: the metadata of the model under attack.
|
||||||
balance_attacker_training: Whether the training and test sets for the
|
balance_attacker_training: Whether the training and test sets for the
|
||||||
membership inference attacker should have a balanced (roughly equal)
|
membership inference attacker should have a balanced (roughly equal)
|
||||||
number of samples from the training and test sets used to develop the
|
number of samples from the training and test sets used to develop the
|
||||||
|
@ -231,8 +309,12 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
||||||
# as it makes the most sense for single-number features.
|
# as it makes the most sense for single-number features.
|
||||||
attacker = models.LogisticRegressionAttacker()
|
attacker = models.LogisticRegressionAttacker()
|
||||||
|
|
||||||
prepared_attacker_data = create_seq2seq_attacker_data(
|
# Create attacker data and populate fields of privacy_report_metadata
|
||||||
attack_input, balance=balance_attacker_training)
|
if privacy_report_metadata is None:
|
||||||
|
privacy_report_metadata = PrivacyReportMetadata()
|
||||||
|
prepared_attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input,
|
||||||
|
balance=balance_attacker_training,
|
||||||
|
privacy_report_metadata=privacy_report_metadata)
|
||||||
|
|
||||||
attacker.train_model(prepared_attacker_data.features_train,
|
attacker.train_model(prepared_attacker_data.features_train,
|
||||||
prepared_attacker_data.is_training_labels_train)
|
prepared_attacker_data.is_training_labels_train)
|
||||||
|
@ -253,5 +335,6 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
||||||
roc_curve=roc_curve)
|
roc_curve=roc_curve)
|
||||||
]
|
]
|
||||||
|
|
||||||
return AttackResults(single_attack_results=attack_results)
|
return AttackResults(
|
||||||
|
single_attack_results=attack_results,
|
||||||
|
privacy_report_metadata=privacy_report_metadata)
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType, PrivacyReportMetadata
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.seq2seq_mia import Seq2SeqAttackInputData, \
|
from tensorflow_privacy.privacy.membership_inference_attack.seq2seq_mia import Seq2SeqAttackInputData, \
|
||||||
create_seq2seq_attacker_data, run_seq2seq_attack
|
create_seq2seq_attacker_data, run_seq2seq_attack
|
||||||
|
|
||||||
|
@ -145,13 +145,23 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
|
||||||
vocab_size=3,
|
vocab_size=3,
|
||||||
train_size=3,
|
train_size=3,
|
||||||
test_size=2)
|
test_size=2)
|
||||||
attacker_data = create_seq2seq_attacker_data(attack_input, 0.25, balance=False)
|
privacy_report_metadata = PrivacyReportMetadata()
|
||||||
|
attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input,
|
||||||
|
test_fraction=0.25,
|
||||||
|
balance=False,
|
||||||
|
privacy_report_metadata=privacy_report_metadata)
|
||||||
self.assertLen(attacker_data.features_train, 3)
|
self.assertLen(attacker_data.features_train, 3)
|
||||||
self.assertLen(attacker_data.features_test, 2)
|
self.assertLen(attacker_data.features_test, 2)
|
||||||
|
|
||||||
for _, feature in enumerate(attacker_data.features_train):
|
for _, feature in enumerate(attacker_data.features_train):
|
||||||
self.assertLen(feature, 1) # each feature has one average rank
|
self.assertLen(feature, 1) # each feature has one average rank
|
||||||
|
|
||||||
|
# Tests that fields of PrivacyReportMetadata are populated.
|
||||||
|
self.assertIsNotNone(privacy_report_metadata.loss_train)
|
||||||
|
self.assertIsNotNone(privacy_report_metadata.loss_test)
|
||||||
|
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
|
||||||
|
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
|
||||||
|
|
||||||
|
|
||||||
def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
|
def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
|
||||||
attack_input = Seq2SeqAttackInputData(
|
attack_input = Seq2SeqAttackInputData(
|
||||||
|
@ -210,13 +220,23 @@ class Seq2SeqTrainedAttackerTest(absltest.TestCase):
|
||||||
vocab_size=3,
|
vocab_size=3,
|
||||||
train_size=3,
|
train_size=3,
|
||||||
test_size=3)
|
test_size=3)
|
||||||
attacker_data = create_seq2seq_attacker_data(attack_input, 0.33, balance=True)
|
privacy_report_metadata = PrivacyReportMetadata()
|
||||||
|
attacker_data = create_seq2seq_attacker_data(attack_input_data=attack_input,
|
||||||
|
test_fraction=0.33,
|
||||||
|
balance=True,
|
||||||
|
privacy_report_metadata=privacy_report_metadata)
|
||||||
self.assertLen(attacker_data.features_train, 4)
|
self.assertLen(attacker_data.features_train, 4)
|
||||||
self.assertLen(attacker_data.features_test, 2)
|
self.assertLen(attacker_data.features_test, 2)
|
||||||
|
|
||||||
for _, feature in enumerate(attacker_data.features_train):
|
for _, feature in enumerate(attacker_data.features_train):
|
||||||
self.assertLen(feature, 1) # each feature has one average rank
|
self.assertLen(feature, 1) # each feature has one average rank
|
||||||
|
|
||||||
|
# Tests that fields of PrivacyReportMetadata are populated.
|
||||||
|
self.assertIsNotNone(privacy_report_metadata.loss_train)
|
||||||
|
self.assertIsNotNone(privacy_report_metadata.loss_test)
|
||||||
|
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
|
||||||
|
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
|
||||||
|
|
||||||
|
|
||||||
def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
|
def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
|
||||||
vocab_size):
|
vocab_size):
|
||||||
|
@ -315,6 +335,21 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
|
||||||
np.testing.assert_almost_equal(
|
np.testing.assert_almost_equal(
|
||||||
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
|
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
|
||||||
|
|
||||||
|
def test_run_seq2seq_attack_calculates_correct_metadata(self):
|
||||||
|
result = run_seq2seq_attack(get_seq2seq_test_input(
|
||||||
|
n_train=20,
|
||||||
|
n_test=10,
|
||||||
|
max_seq_in_batch=3,
|
||||||
|
max_tokens_in_sequence=5,
|
||||||
|
vocab_size=3,
|
||||||
|
seed=12345),
|
||||||
|
balance_attacker_training=False)
|
||||||
|
metadata = result.privacy_report_metadata
|
||||||
|
np.testing.assert_almost_equal(metadata.loss_train, 1.11, decimal=2)
|
||||||
|
np.testing.assert_almost_equal(metadata.loss_test, 1.10, decimal=2)
|
||||||
|
np.testing.assert_almost_equal(metadata.accuracy_train, 0.40, decimal=2)
|
||||||
|
np.testing.assert_almost_equal(metadata.accuracy_test, 0.34, decimal=2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
Loading…
Reference in a new issue