For seq2seq MIA test, call threshold attacker directly.
PiperOrigin-RevId: 426941426
This commit is contained in:
parent
2fe51d2eeb
commit
ceced43d0b
2 changed files with 21 additions and 254 deletions
|
@ -23,19 +23,12 @@ from typing import Iterator, List, Optional, Tuple
|
|||
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
from sklearn import metrics
|
||||
from sklearn import model_selection
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import models
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import RocCurve
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleAttackResult
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleSliceSpec
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.models import _sample_multidimensional_array
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.models import AttackerData
|
||||
|
||||
|
||||
def _is_iterator(obj, obj_name):
|
||||
|
@ -266,66 +259,6 @@ def _get_batch_accuracy_metrics(
|
|||
return batch_correct_preds, batch_total_preds
|
||||
|
||||
|
||||
def create_seq2seq_attacker_data(
|
||||
attack_input_data: Seq2SeqAttackInputData,
|
||||
test_fraction: float = 0.25,
|
||||
balance: bool = True,
|
||||
privacy_report_metadata: PrivacyReportMetadata = PrivacyReportMetadata()
|
||||
) -> AttackerData:
|
||||
"""Prepares Seq2SeqAttackInputData to train ML attackers.
|
||||
|
||||
Uses logits and losses to generate ranks and performs a random train-test
|
||||
split.
|
||||
|
||||
Also computes metadata (loss, accuracy) for the model under attack
|
||||
and populates respective fields of PrivacyReportMetadata.
|
||||
|
||||
Args:
|
||||
attack_input_data: Original Seq2SeqAttackInputData
|
||||
test_fraction: Fraction of the dataset to include in the test split.
|
||||
balance: Whether the training and test sets for the membership inference
|
||||
attacker should have a balanced (roughly equal) number of samples from the
|
||||
training and test sets used to develop the model under attack.
|
||||
privacy_report_metadata: the metadata of the model under attack.
|
||||
|
||||
Returns:
|
||||
AttackerData.
|
||||
"""
|
||||
attack_input_train, loss_train, accuracy_train = _get_attack_features_and_metadata(
|
||||
attack_input_data.logits_train, attack_input_data.labels_train)
|
||||
attack_input_test, loss_test, accuracy_test = _get_attack_features_and_metadata(
|
||||
attack_input_data.logits_test, attack_input_data.labels_test)
|
||||
|
||||
if balance:
|
||||
min_size = min(len(attack_input_train), len(attack_input_test))
|
||||
attack_input_train = _sample_multidimensional_array(attack_input_train,
|
||||
min_size)
|
||||
attack_input_test = _sample_multidimensional_array(attack_input_test,
|
||||
min_size)
|
||||
|
||||
features_all = np.concatenate((attack_input_train, attack_input_test))
|
||||
ntrain, ntest = attack_input_train.shape[0], attack_input_test.shape[0]
|
||||
|
||||
# Reshape for classifying one-dimensional features
|
||||
features_all = features_all.reshape(-1, 1)
|
||||
|
||||
labels_all = np.concatenate(((np.zeros(ntrain)), (np.ones(ntest))))
|
||||
|
||||
# Perform a train-test split
|
||||
features_train, features_test, is_training_labels_train, is_training_labels_test = model_selection.train_test_split(
|
||||
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
|
||||
|
||||
# Populate accuracy, loss fields in privacy report metadata
|
||||
privacy_report_metadata.loss_train = loss_train
|
||||
privacy_report_metadata.loss_test = loss_test
|
||||
privacy_report_metadata.accuracy_train = accuracy_train
|
||||
privacy_report_metadata.accuracy_test = accuracy_test
|
||||
|
||||
return AttackerData(features_train, is_training_labels_train, features_test,
|
||||
is_training_labels_test,
|
||||
DataSize(ntrain=ntrain, ntest=ntest))
|
||||
|
||||
|
||||
def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
||||
privacy_report_metadata: PrivacyReportMetadata = None,
|
||||
balance_attacker_training: bool = True) -> AttackResults:
|
||||
|
@ -343,39 +276,23 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
|||
the attack result.
|
||||
"""
|
||||
attack_input.validate()
|
||||
attack_input_train, loss_train, accuracy_train = _get_attack_features_and_metadata(
|
||||
attack_input.logits_train, attack_input.labels_train)
|
||||
attack_input_test, loss_test, accuracy_test = _get_attack_features_and_metadata(
|
||||
attack_input.logits_test, attack_input.labels_test)
|
||||
|
||||
# The attacker uses the average rank (a single number) of a seq2seq dataset
|
||||
# record to determine membership. So only Logistic Regression is supported,
|
||||
# as it makes the most sense for single-number features.
|
||||
attacker = models.LogisticRegressionAttacker()
|
||||
|
||||
# Create attacker data and populate fields of privacy_report_metadata
|
||||
privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
|
||||
prepared_attacker_data = create_seq2seq_attacker_data(
|
||||
attack_input_data=attack_input,
|
||||
balance=balance_attacker_training,
|
||||
privacy_report_metadata=privacy_report_metadata)
|
||||
privacy_report_metadata.loss_train = loss_train
|
||||
privacy_report_metadata.loss_test = loss_test
|
||||
privacy_report_metadata.accuracy_train = accuracy_train
|
||||
privacy_report_metadata.accuracy_test = accuracy_test
|
||||
|
||||
attacker.train_model(prepared_attacker_data.features_train,
|
||||
prepared_attacker_data.is_training_labels_train)
|
||||
|
||||
# Run the attacker on (permuted) test examples.
|
||||
predictions_test = attacker.predict(prepared_attacker_data.features_test)
|
||||
|
||||
# Generate ROC curves with predictions.
|
||||
fpr, tpr, thresholds = metrics.roc_curve(
|
||||
prepared_attacker_data.is_training_labels_test, predictions_test)
|
||||
|
||||
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
|
||||
|
||||
attack_results = [
|
||||
SingleAttackResult(
|
||||
slice_spec=SingleSliceSpec(),
|
||||
attack_type=AttackType.LOGISTIC_REGRESSION,
|
||||
roc_curve=roc_curve,
|
||||
data_size=prepared_attacker_data.data_size)
|
||||
]
|
||||
|
||||
return AttackResults(
|
||||
single_attack_results=attack_results,
|
||||
privacy_report_metadata=privacy_report_metadata)
|
||||
# `attack_input_train` and `attack_input_test` contains the rank of the
|
||||
# ground-truth label in the logit, so smaller value means an example is
|
||||
# more likely a training example.
|
||||
return mia.run_attacks(
|
||||
AttackInputData(
|
||||
loss_train=attack_input_train, loss_test=attack_input_test),
|
||||
attack_types=(AttackType.THRESHOLD_ATTACK,),
|
||||
privacy_report_metadata=privacy_report_metadata,
|
||||
balance_attacker_training=balance_attacker_training)
|
||||
|
|
|
@ -16,8 +16,6 @@ from absl.testing import absltest
|
|||
import numpy as np
|
||||
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import create_seq2seq_attacker_data
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import run_seq2seq_attack
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import Seq2SeqAttackInputData
|
||||
|
||||
|
@ -91,154 +89,6 @@ class Seq2SeqAttackInputDataTest(absltest.TestCase):
|
|||
test_size=0).validate)
|
||||
|
||||
|
||||
class Seq2SeqTrainedAttackerTest(absltest.TestCase):
|
||||
|
||||
def test_create_seq2seq_attacker_data_logits_and_labels(self):
|
||||
attack_input = Seq2SeqAttackInputData(
|
||||
logits_train=iter([
|
||||
np.array([
|
||||
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
|
||||
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
np.array(
|
||||
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
|
||||
dtype=object),
|
||||
np.array([
|
||||
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
|
||||
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
]),
|
||||
logits_test=iter([
|
||||
np.array([
|
||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
np.array([
|
||||
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
|
||||
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
]),
|
||||
labels_train=iter([
|
||||
np.array([
|
||||
np.array([2, 0], dtype=np.float32),
|
||||
np.array([1], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
|
||||
np.array([
|
||||
np.array([0, 1], dtype=np.float32),
|
||||
np.array([1, 2], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
]),
|
||||
labels_test=iter([
|
||||
np.array([np.array([2, 1], dtype=np.float32)]),
|
||||
np.array([
|
||||
np.array([2, 0], dtype=np.float32),
|
||||
np.array([1], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
]),
|
||||
vocab_size=3,
|
||||
train_size=3,
|
||||
test_size=2)
|
||||
privacy_report_metadata = PrivacyReportMetadata()
|
||||
attacker_data = create_seq2seq_attacker_data(
|
||||
attack_input_data=attack_input,
|
||||
test_fraction=0.25,
|
||||
balance=False,
|
||||
privacy_report_metadata=privacy_report_metadata)
|
||||
self.assertLen(attacker_data.features_train, 3)
|
||||
self.assertLen(attacker_data.features_test, 2)
|
||||
|
||||
for _, feature in enumerate(attacker_data.features_train):
|
||||
self.assertLen(feature, 1) # each feature has one average rank
|
||||
|
||||
# Tests that fields of PrivacyReportMetadata are populated.
|
||||
self.assertIsNotNone(privacy_report_metadata.loss_train)
|
||||
self.assertIsNotNone(privacy_report_metadata.loss_test)
|
||||
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
|
||||
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
|
||||
|
||||
def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
|
||||
attack_input = Seq2SeqAttackInputData(
|
||||
logits_train=iter([
|
||||
np.array([
|
||||
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
|
||||
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
np.array(
|
||||
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
|
||||
dtype=object),
|
||||
np.array([
|
||||
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
|
||||
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
]),
|
||||
logits_test=iter([
|
||||
np.array([
|
||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
np.array([
|
||||
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
|
||||
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
np.array([
|
||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
]),
|
||||
labels_train=iter([
|
||||
np.array([
|
||||
np.array([2, 0], dtype=np.float32),
|
||||
np.array([1], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
|
||||
np.array([
|
||||
np.array([0, 1], dtype=np.float32),
|
||||
np.array([1, 2], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
]),
|
||||
labels_test=iter([
|
||||
np.array([np.array([2, 1], dtype=np.float32)]),
|
||||
np.array([
|
||||
np.array([2, 0], dtype=np.float32),
|
||||
np.array([1], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
np.array([np.array([2, 1], dtype=np.float32)])
|
||||
]),
|
||||
vocab_size=3,
|
||||
train_size=3,
|
||||
test_size=3)
|
||||
privacy_report_metadata = PrivacyReportMetadata()
|
||||
attacker_data = create_seq2seq_attacker_data(
|
||||
attack_input_data=attack_input,
|
||||
test_fraction=0.33,
|
||||
balance=True,
|
||||
privacy_report_metadata=privacy_report_metadata)
|
||||
self.assertLen(attacker_data.features_train, 4)
|
||||
self.assertLen(attacker_data.features_test, 2)
|
||||
|
||||
for _, feature in enumerate(attacker_data.features_train):
|
||||
self.assertLen(feature, 1) # each feature has one average rank
|
||||
|
||||
# Tests that fields of PrivacyReportMetadata are populated.
|
||||
self.assertIsNotNone(privacy_report_metadata.loss_train)
|
||||
self.assertIsNotNone(privacy_report_metadata.loss_test)
|
||||
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
|
||||
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
|
||||
|
||||
|
||||
def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
|
||||
vocab_size):
|
||||
num_tokens_in_sequence = np.random.choice(max_tokens_in_sequence,
|
||||
|
@ -323,7 +173,7 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
|
|||
max_tokens_in_sequence=5,
|
||||
vocab_size=2))
|
||||
seq2seq_result = list(result.single_attack_results)[0]
|
||||
self.assertEqual(seq2seq_result.attack_type, AttackType.LOGISTIC_REGRESSION)
|
||||
self.assertEqual(seq2seq_result.attack_type, AttackType.THRESHOLD_ATTACK)
|
||||
|
||||
def test_run_seq2seq_attack_calculates_correct_auc(self):
|
||||
result = run_seq2seq_attack(
|
||||
|
@ -337,7 +187,7 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
|
|||
balance_attacker_training=False)
|
||||
seq2seq_result = list(result.single_attack_results)[0]
|
||||
np.testing.assert_almost_equal(
|
||||
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
|
||||
seq2seq_result.roc_curve.get_auc(), 0.59, decimal=2)
|
||||
|
||||
def test_run_seq2seq_attack_calculates_correct_metadata(self):
|
||||
attack_input = Seq2SeqAttackInputData(
|
||||
|
|
Loading…
Reference in a new issue