For seq2seq MIA test, call threshold attacker directly.

PiperOrigin-RevId: 426941426
This commit is contained in:
Shuang Song 2022-02-07 09:45:36 -08:00 committed by A. Unique TensorFlower
parent 2fe51d2eeb
commit ceced43d0b
2 changed files with 21 additions and 254 deletions

View file

@ -23,19 +23,12 @@ from typing import Iterator, List, Optional, Tuple
import numpy as np
from scipy import stats
from sklearn import metrics
from sklearn import model_selection
import tensorflow as tf
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import models
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleAttackResult
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleSliceSpec
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.models import _sample_multidimensional_array
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.models import AttackerData
def _is_iterator(obj, obj_name):
@ -266,66 +259,6 @@ def _get_batch_accuracy_metrics(
return batch_correct_preds, batch_total_preds
def create_seq2seq_attacker_data(
attack_input_data: Seq2SeqAttackInputData,
test_fraction: float = 0.25,
balance: bool = True,
privacy_report_metadata: PrivacyReportMetadata = PrivacyReportMetadata()
) -> AttackerData:
"""Prepares Seq2SeqAttackInputData to train ML attackers.
Uses logits and losses to generate ranks and performs a random train-test
split.
Also computes metadata (loss, accuracy) for the model under attack
and populates respective fields of PrivacyReportMetadata.
Args:
attack_input_data: Original Seq2SeqAttackInputData
test_fraction: Fraction of the dataset to include in the test split.
balance: Whether the training and test sets for the membership inference
attacker should have a balanced (roughly equal) number of samples from the
training and test sets used to develop the model under attack.
privacy_report_metadata: the metadata of the model under attack.
Returns:
AttackerData.
"""
attack_input_train, loss_train, accuracy_train = _get_attack_features_and_metadata(
attack_input_data.logits_train, attack_input_data.labels_train)
attack_input_test, loss_test, accuracy_test = _get_attack_features_and_metadata(
attack_input_data.logits_test, attack_input_data.labels_test)
if balance:
min_size = min(len(attack_input_train), len(attack_input_test))
attack_input_train = _sample_multidimensional_array(attack_input_train,
min_size)
attack_input_test = _sample_multidimensional_array(attack_input_test,
min_size)
features_all = np.concatenate((attack_input_train, attack_input_test))
ntrain, ntest = attack_input_train.shape[0], attack_input_test.shape[0]
# Reshape for classifying one-dimensional features
features_all = features_all.reshape(-1, 1)
labels_all = np.concatenate(((np.zeros(ntrain)), (np.ones(ntest))))
# Perform a train-test split
features_train, features_test, is_training_labels_train, is_training_labels_test = model_selection.train_test_split(
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
# Populate accuracy, loss fields in privacy report metadata
privacy_report_metadata.loss_train = loss_train
privacy_report_metadata.loss_test = loss_test
privacy_report_metadata.accuracy_train = accuracy_train
privacy_report_metadata.accuracy_test = accuracy_test
return AttackerData(features_train, is_training_labels_train, features_test,
is_training_labels_test,
DataSize(ntrain=ntrain, ntest=ntest))
def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
privacy_report_metadata: PrivacyReportMetadata = None,
balance_attacker_training: bool = True) -> AttackResults:
@ -343,39 +276,23 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
the attack result.
"""
attack_input.validate()
attack_input_train, loss_train, accuracy_train = _get_attack_features_and_metadata(
attack_input.logits_train, attack_input.labels_train)
attack_input_test, loss_test, accuracy_test = _get_attack_features_and_metadata(
attack_input.logits_test, attack_input.labels_test)
# The attacker uses the average rank (a single number) of a seq2seq dataset
# record to determine membership. So only Logistic Regression is supported,
# as it makes the most sense for single-number features.
attacker = models.LogisticRegressionAttacker()
# Create attacker data and populate fields of privacy_report_metadata
privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
prepared_attacker_data = create_seq2seq_attacker_data(
attack_input_data=attack_input,
balance=balance_attacker_training,
privacy_report_metadata=privacy_report_metadata)
privacy_report_metadata.loss_train = loss_train
privacy_report_metadata.loss_test = loss_test
privacy_report_metadata.accuracy_train = accuracy_train
privacy_report_metadata.accuracy_test = accuracy_test
attacker.train_model(prepared_attacker_data.features_train,
prepared_attacker_data.is_training_labels_train)
# Run the attacker on (permuted) test examples.
predictions_test = attacker.predict(prepared_attacker_data.features_test)
# Generate ROC curves with predictions.
fpr, tpr, thresholds = metrics.roc_curve(
prepared_attacker_data.is_training_labels_test, predictions_test)
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
attack_results = [
SingleAttackResult(
slice_spec=SingleSliceSpec(),
attack_type=AttackType.LOGISTIC_REGRESSION,
roc_curve=roc_curve,
data_size=prepared_attacker_data.data_size)
]
return AttackResults(
single_attack_results=attack_results,
privacy_report_metadata=privacy_report_metadata)
# `attack_input_train` and `attack_input_test` contains the rank of the
# ground-truth label in the logit, so smaller value means an example is
# more likely a training example.
return mia.run_attacks(
AttackInputData(
loss_train=attack_input_train, loss_test=attack_input_test),
attack_types=(AttackType.THRESHOLD_ATTACK,),
privacy_report_metadata=privacy_report_metadata,
balance_attacker_training=balance_attacker_training)

View file

@ -16,8 +16,6 @@ from absl.testing import absltest
import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import create_seq2seq_attacker_data
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import run_seq2seq_attack
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import Seq2SeqAttackInputData
@ -91,154 +89,6 @@ class Seq2SeqAttackInputDataTest(absltest.TestCase):
test_size=0).validate)
class Seq2SeqTrainedAttackerTest(absltest.TestCase):
def test_create_seq2seq_attacker_data_logits_and_labels(self):
attack_input = Seq2SeqAttackInputData(
logits_train=iter([
np.array([
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
],
dtype=object),
np.array(
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
dtype=object),
np.array([
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
],
dtype=object)
]),
logits_test=iter([
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
],
dtype=object),
np.array([
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
],
dtype=object)
]),
labels_train=iter([
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
],
dtype=object),
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
np.array([
np.array([0, 1], dtype=np.float32),
np.array([1, 2], dtype=np.float32)
],
dtype=object)
]),
labels_test=iter([
np.array([np.array([2, 1], dtype=np.float32)]),
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
],
dtype=object)
]),
vocab_size=3,
train_size=3,
test_size=2)
privacy_report_metadata = PrivacyReportMetadata()
attacker_data = create_seq2seq_attacker_data(
attack_input_data=attack_input,
test_fraction=0.25,
balance=False,
privacy_report_metadata=privacy_report_metadata)
self.assertLen(attacker_data.features_train, 3)
self.assertLen(attacker_data.features_test, 2)
for _, feature in enumerate(attacker_data.features_train):
self.assertLen(feature, 1) # each feature has one average rank
# Tests that fields of PrivacyReportMetadata are populated.
self.assertIsNotNone(privacy_report_metadata.loss_train)
self.assertIsNotNone(privacy_report_metadata.loss_test)
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
attack_input = Seq2SeqAttackInputData(
logits_train=iter([
np.array([
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
],
dtype=object),
np.array(
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
dtype=object),
np.array([
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
],
dtype=object)
]),
logits_test=iter([
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
],
dtype=object),
np.array([
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
],
dtype=object),
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
],
dtype=object)
]),
labels_train=iter([
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
],
dtype=object),
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
np.array([
np.array([0, 1], dtype=np.float32),
np.array([1, 2], dtype=np.float32)
],
dtype=object)
]),
labels_test=iter([
np.array([np.array([2, 1], dtype=np.float32)]),
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
],
dtype=object),
np.array([np.array([2, 1], dtype=np.float32)])
]),
vocab_size=3,
train_size=3,
test_size=3)
privacy_report_metadata = PrivacyReportMetadata()
attacker_data = create_seq2seq_attacker_data(
attack_input_data=attack_input,
test_fraction=0.33,
balance=True,
privacy_report_metadata=privacy_report_metadata)
self.assertLen(attacker_data.features_train, 4)
self.assertLen(attacker_data.features_test, 2)
for _, feature in enumerate(attacker_data.features_train):
self.assertLen(feature, 1) # each feature has one average rank
# Tests that fields of PrivacyReportMetadata are populated.
self.assertIsNotNone(privacy_report_metadata.loss_train)
self.assertIsNotNone(privacy_report_metadata.loss_test)
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
vocab_size):
num_tokens_in_sequence = np.random.choice(max_tokens_in_sequence,
@ -323,7 +173,7 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
max_tokens_in_sequence=5,
vocab_size=2))
seq2seq_result = list(result.single_attack_results)[0]
self.assertEqual(seq2seq_result.attack_type, AttackType.LOGISTIC_REGRESSION)
self.assertEqual(seq2seq_result.attack_type, AttackType.THRESHOLD_ATTACK)
def test_run_seq2seq_attack_calculates_correct_auc(self):
result = run_seq2seq_attack(
@ -337,7 +187,7 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
balance_attacker_training=False)
seq2seq_result = list(result.single_attack_results)[0]
np.testing.assert_almost_equal(
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
seq2seq_result.roc_curve.get_auc(), 0.59, decimal=2)
def test_run_seq2seq_attack_calculates_correct_metadata(self):
attack_input = Seq2SeqAttackInputData(