For seq2seq MIA test, call threshold attacker directly.
PiperOrigin-RevId: 426941426
This commit is contained in:
parent
2fe51d2eeb
commit
ceced43d0b
2 changed files with 21 additions and 254 deletions
|
@ -23,19 +23,12 @@ from typing import Iterator, List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import stats
|
from scipy import stats
|
||||||
from sklearn import metrics
|
|
||||||
from sklearn import model_selection
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import models
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
|
||||||
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import RocCurve
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleAttackResult
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleSliceSpec
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.models import _sample_multidimensional_array
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.models import AttackerData
|
|
||||||
|
|
||||||
|
|
||||||
def _is_iterator(obj, obj_name):
|
def _is_iterator(obj, obj_name):
|
||||||
|
@ -266,66 +259,6 @@ def _get_batch_accuracy_metrics(
|
||||||
return batch_correct_preds, batch_total_preds
|
return batch_correct_preds, batch_total_preds
|
||||||
|
|
||||||
|
|
||||||
def create_seq2seq_attacker_data(
|
|
||||||
attack_input_data: Seq2SeqAttackInputData,
|
|
||||||
test_fraction: float = 0.25,
|
|
||||||
balance: bool = True,
|
|
||||||
privacy_report_metadata: PrivacyReportMetadata = PrivacyReportMetadata()
|
|
||||||
) -> AttackerData:
|
|
||||||
"""Prepares Seq2SeqAttackInputData to train ML attackers.
|
|
||||||
|
|
||||||
Uses logits and losses to generate ranks and performs a random train-test
|
|
||||||
split.
|
|
||||||
|
|
||||||
Also computes metadata (loss, accuracy) for the model under attack
|
|
||||||
and populates respective fields of PrivacyReportMetadata.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
attack_input_data: Original Seq2SeqAttackInputData
|
|
||||||
test_fraction: Fraction of the dataset to include in the test split.
|
|
||||||
balance: Whether the training and test sets for the membership inference
|
|
||||||
attacker should have a balanced (roughly equal) number of samples from the
|
|
||||||
training and test sets used to develop the model under attack.
|
|
||||||
privacy_report_metadata: the metadata of the model under attack.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AttackerData.
|
|
||||||
"""
|
|
||||||
attack_input_train, loss_train, accuracy_train = _get_attack_features_and_metadata(
|
|
||||||
attack_input_data.logits_train, attack_input_data.labels_train)
|
|
||||||
attack_input_test, loss_test, accuracy_test = _get_attack_features_and_metadata(
|
|
||||||
attack_input_data.logits_test, attack_input_data.labels_test)
|
|
||||||
|
|
||||||
if balance:
|
|
||||||
min_size = min(len(attack_input_train), len(attack_input_test))
|
|
||||||
attack_input_train = _sample_multidimensional_array(attack_input_train,
|
|
||||||
min_size)
|
|
||||||
attack_input_test = _sample_multidimensional_array(attack_input_test,
|
|
||||||
min_size)
|
|
||||||
|
|
||||||
features_all = np.concatenate((attack_input_train, attack_input_test))
|
|
||||||
ntrain, ntest = attack_input_train.shape[0], attack_input_test.shape[0]
|
|
||||||
|
|
||||||
# Reshape for classifying one-dimensional features
|
|
||||||
features_all = features_all.reshape(-1, 1)
|
|
||||||
|
|
||||||
labels_all = np.concatenate(((np.zeros(ntrain)), (np.ones(ntest))))
|
|
||||||
|
|
||||||
# Perform a train-test split
|
|
||||||
features_train, features_test, is_training_labels_train, is_training_labels_test = model_selection.train_test_split(
|
|
||||||
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
|
|
||||||
|
|
||||||
# Populate accuracy, loss fields in privacy report metadata
|
|
||||||
privacy_report_metadata.loss_train = loss_train
|
|
||||||
privacy_report_metadata.loss_test = loss_test
|
|
||||||
privacy_report_metadata.accuracy_train = accuracy_train
|
|
||||||
privacy_report_metadata.accuracy_test = accuracy_test
|
|
||||||
|
|
||||||
return AttackerData(features_train, is_training_labels_train, features_test,
|
|
||||||
is_training_labels_test,
|
|
||||||
DataSize(ntrain=ntrain, ntest=ntest))
|
|
||||||
|
|
||||||
|
|
||||||
def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
||||||
privacy_report_metadata: PrivacyReportMetadata = None,
|
privacy_report_metadata: PrivacyReportMetadata = None,
|
||||||
balance_attacker_training: bool = True) -> AttackResults:
|
balance_attacker_training: bool = True) -> AttackResults:
|
||||||
|
@ -343,39 +276,23 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
||||||
the attack result.
|
the attack result.
|
||||||
"""
|
"""
|
||||||
attack_input.validate()
|
attack_input.validate()
|
||||||
|
attack_input_train, loss_train, accuracy_train = _get_attack_features_and_metadata(
|
||||||
|
attack_input.logits_train, attack_input.labels_train)
|
||||||
|
attack_input_test, loss_test, accuracy_test = _get_attack_features_and_metadata(
|
||||||
|
attack_input.logits_test, attack_input.labels_test)
|
||||||
|
|
||||||
# The attacker uses the average rank (a single number) of a seq2seq dataset
|
|
||||||
# record to determine membership. So only Logistic Regression is supported,
|
|
||||||
# as it makes the most sense for single-number features.
|
|
||||||
attacker = models.LogisticRegressionAttacker()
|
|
||||||
|
|
||||||
# Create attacker data and populate fields of privacy_report_metadata
|
|
||||||
privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
|
privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
|
||||||
prepared_attacker_data = create_seq2seq_attacker_data(
|
privacy_report_metadata.loss_train = loss_train
|
||||||
attack_input_data=attack_input,
|
privacy_report_metadata.loss_test = loss_test
|
||||||
balance=balance_attacker_training,
|
privacy_report_metadata.accuracy_train = accuracy_train
|
||||||
privacy_report_metadata=privacy_report_metadata)
|
privacy_report_metadata.accuracy_test = accuracy_test
|
||||||
|
|
||||||
attacker.train_model(prepared_attacker_data.features_train,
|
# `attack_input_train` and `attack_input_test` contains the rank of the
|
||||||
prepared_attacker_data.is_training_labels_train)
|
# ground-truth label in the logit, so smaller value means an example is
|
||||||
|
# more likely a training example.
|
||||||
# Run the attacker on (permuted) test examples.
|
return mia.run_attacks(
|
||||||
predictions_test = attacker.predict(prepared_attacker_data.features_test)
|
AttackInputData(
|
||||||
|
loss_train=attack_input_train, loss_test=attack_input_test),
|
||||||
# Generate ROC curves with predictions.
|
attack_types=(AttackType.THRESHOLD_ATTACK,),
|
||||||
fpr, tpr, thresholds = metrics.roc_curve(
|
privacy_report_metadata=privacy_report_metadata,
|
||||||
prepared_attacker_data.is_training_labels_test, predictions_test)
|
balance_attacker_training=balance_attacker_training)
|
||||||
|
|
||||||
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
|
|
||||||
|
|
||||||
attack_results = [
|
|
||||||
SingleAttackResult(
|
|
||||||
slice_spec=SingleSliceSpec(),
|
|
||||||
attack_type=AttackType.LOGISTIC_REGRESSION,
|
|
||||||
roc_curve=roc_curve,
|
|
||||||
data_size=prepared_attacker_data.data_size)
|
|
||||||
]
|
|
||||||
|
|
||||||
return AttackResults(
|
|
||||||
single_attack_results=attack_results,
|
|
||||||
privacy_report_metadata=privacy_report_metadata)
|
|
||||||
|
|
|
@ -16,8 +16,6 @@ from absl.testing import absltest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import create_seq2seq_attacker_data
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import run_seq2seq_attack
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import run_seq2seq_attack
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import Seq2SeqAttackInputData
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import Seq2SeqAttackInputData
|
||||||
|
|
||||||
|
@ -91,154 +89,6 @@ class Seq2SeqAttackInputDataTest(absltest.TestCase):
|
||||||
test_size=0).validate)
|
test_size=0).validate)
|
||||||
|
|
||||||
|
|
||||||
class Seq2SeqTrainedAttackerTest(absltest.TestCase):
|
|
||||||
|
|
||||||
def test_create_seq2seq_attacker_data_logits_and_labels(self):
|
|
||||||
attack_input = Seq2SeqAttackInputData(
|
|
||||||
logits_train=iter([
|
|
||||||
np.array([
|
|
||||||
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
|
|
||||||
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
|
|
||||||
],
|
|
||||||
dtype=object),
|
|
||||||
np.array(
|
|
||||||
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
|
|
||||||
dtype=object),
|
|
||||||
np.array([
|
|
||||||
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
|
|
||||||
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
|
|
||||||
],
|
|
||||||
dtype=object)
|
|
||||||
]),
|
|
||||||
logits_test=iter([
|
|
||||||
np.array([
|
|
||||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
|
||||||
],
|
|
||||||
dtype=object),
|
|
||||||
np.array([
|
|
||||||
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
|
|
||||||
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
|
|
||||||
],
|
|
||||||
dtype=object)
|
|
||||||
]),
|
|
||||||
labels_train=iter([
|
|
||||||
np.array([
|
|
||||||
np.array([2, 0], dtype=np.float32),
|
|
||||||
np.array([1], dtype=np.float32)
|
|
||||||
],
|
|
||||||
dtype=object),
|
|
||||||
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
|
|
||||||
np.array([
|
|
||||||
np.array([0, 1], dtype=np.float32),
|
|
||||||
np.array([1, 2], dtype=np.float32)
|
|
||||||
],
|
|
||||||
dtype=object)
|
|
||||||
]),
|
|
||||||
labels_test=iter([
|
|
||||||
np.array([np.array([2, 1], dtype=np.float32)]),
|
|
||||||
np.array([
|
|
||||||
np.array([2, 0], dtype=np.float32),
|
|
||||||
np.array([1], dtype=np.float32)
|
|
||||||
],
|
|
||||||
dtype=object)
|
|
||||||
]),
|
|
||||||
vocab_size=3,
|
|
||||||
train_size=3,
|
|
||||||
test_size=2)
|
|
||||||
privacy_report_metadata = PrivacyReportMetadata()
|
|
||||||
attacker_data = create_seq2seq_attacker_data(
|
|
||||||
attack_input_data=attack_input,
|
|
||||||
test_fraction=0.25,
|
|
||||||
balance=False,
|
|
||||||
privacy_report_metadata=privacy_report_metadata)
|
|
||||||
self.assertLen(attacker_data.features_train, 3)
|
|
||||||
self.assertLen(attacker_data.features_test, 2)
|
|
||||||
|
|
||||||
for _, feature in enumerate(attacker_data.features_train):
|
|
||||||
self.assertLen(feature, 1) # each feature has one average rank
|
|
||||||
|
|
||||||
# Tests that fields of PrivacyReportMetadata are populated.
|
|
||||||
self.assertIsNotNone(privacy_report_metadata.loss_train)
|
|
||||||
self.assertIsNotNone(privacy_report_metadata.loss_test)
|
|
||||||
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
|
|
||||||
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
|
|
||||||
|
|
||||||
def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
|
|
||||||
attack_input = Seq2SeqAttackInputData(
|
|
||||||
logits_train=iter([
|
|
||||||
np.array([
|
|
||||||
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
|
|
||||||
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
|
|
||||||
],
|
|
||||||
dtype=object),
|
|
||||||
np.array(
|
|
||||||
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
|
|
||||||
dtype=object),
|
|
||||||
np.array([
|
|
||||||
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
|
|
||||||
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
|
|
||||||
],
|
|
||||||
dtype=object)
|
|
||||||
]),
|
|
||||||
logits_test=iter([
|
|
||||||
np.array([
|
|
||||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
|
||||||
],
|
|
||||||
dtype=object),
|
|
||||||
np.array([
|
|
||||||
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
|
|
||||||
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
|
|
||||||
],
|
|
||||||
dtype=object),
|
|
||||||
np.array([
|
|
||||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
|
||||||
],
|
|
||||||
dtype=object)
|
|
||||||
]),
|
|
||||||
labels_train=iter([
|
|
||||||
np.array([
|
|
||||||
np.array([2, 0], dtype=np.float32),
|
|
||||||
np.array([1], dtype=np.float32)
|
|
||||||
],
|
|
||||||
dtype=object),
|
|
||||||
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
|
|
||||||
np.array([
|
|
||||||
np.array([0, 1], dtype=np.float32),
|
|
||||||
np.array([1, 2], dtype=np.float32)
|
|
||||||
],
|
|
||||||
dtype=object)
|
|
||||||
]),
|
|
||||||
labels_test=iter([
|
|
||||||
np.array([np.array([2, 1], dtype=np.float32)]),
|
|
||||||
np.array([
|
|
||||||
np.array([2, 0], dtype=np.float32),
|
|
||||||
np.array([1], dtype=np.float32)
|
|
||||||
],
|
|
||||||
dtype=object),
|
|
||||||
np.array([np.array([2, 1], dtype=np.float32)])
|
|
||||||
]),
|
|
||||||
vocab_size=3,
|
|
||||||
train_size=3,
|
|
||||||
test_size=3)
|
|
||||||
privacy_report_metadata = PrivacyReportMetadata()
|
|
||||||
attacker_data = create_seq2seq_attacker_data(
|
|
||||||
attack_input_data=attack_input,
|
|
||||||
test_fraction=0.33,
|
|
||||||
balance=True,
|
|
||||||
privacy_report_metadata=privacy_report_metadata)
|
|
||||||
self.assertLen(attacker_data.features_train, 4)
|
|
||||||
self.assertLen(attacker_data.features_test, 2)
|
|
||||||
|
|
||||||
for _, feature in enumerate(attacker_data.features_train):
|
|
||||||
self.assertLen(feature, 1) # each feature has one average rank
|
|
||||||
|
|
||||||
# Tests that fields of PrivacyReportMetadata are populated.
|
|
||||||
self.assertIsNotNone(privacy_report_metadata.loss_train)
|
|
||||||
self.assertIsNotNone(privacy_report_metadata.loss_test)
|
|
||||||
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
|
|
||||||
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
|
def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
|
||||||
vocab_size):
|
vocab_size):
|
||||||
num_tokens_in_sequence = np.random.choice(max_tokens_in_sequence,
|
num_tokens_in_sequence = np.random.choice(max_tokens_in_sequence,
|
||||||
|
@ -323,7 +173,7 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
|
||||||
max_tokens_in_sequence=5,
|
max_tokens_in_sequence=5,
|
||||||
vocab_size=2))
|
vocab_size=2))
|
||||||
seq2seq_result = list(result.single_attack_results)[0]
|
seq2seq_result = list(result.single_attack_results)[0]
|
||||||
self.assertEqual(seq2seq_result.attack_type, AttackType.LOGISTIC_REGRESSION)
|
self.assertEqual(seq2seq_result.attack_type, AttackType.THRESHOLD_ATTACK)
|
||||||
|
|
||||||
def test_run_seq2seq_attack_calculates_correct_auc(self):
|
def test_run_seq2seq_attack_calculates_correct_auc(self):
|
||||||
result = run_seq2seq_attack(
|
result = run_seq2seq_attack(
|
||||||
|
@ -337,7 +187,7 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
|
||||||
balance_attacker_training=False)
|
balance_attacker_training=False)
|
||||||
seq2seq_result = list(result.single_attack_results)[0]
|
seq2seq_result = list(result.single_attack_results)[0]
|
||||||
np.testing.assert_almost_equal(
|
np.testing.assert_almost_equal(
|
||||||
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
|
seq2seq_result.roc_curve.get_auc(), 0.59, decimal=2)
|
||||||
|
|
||||||
def test_run_seq2seq_attack_calculates_correct_metadata(self):
|
def test_run_seq2seq_attack_calculates_correct_metadata(self):
|
||||||
attack_input = Seq2SeqAttackInputData(
|
attack_input = Seq2SeqAttackInputData(
|
||||||
|
|
Loading…
Reference in a new issue