For seq2seq MIA test, call threshold attacker directly.

PiperOrigin-RevId: 426941426
This commit is contained in:
Shuang Song 2022-02-07 09:45:36 -08:00 committed by A. Unique TensorFlower
parent 2fe51d2eeb
commit ceced43d0b
2 changed files with 21 additions and 254 deletions

View file

@ -23,19 +23,12 @@ from typing import Iterator, List, Optional, Tuple
import numpy as np import numpy as np
from scipy import stats from scipy import stats
from sklearn import metrics
from sklearn import model_selection
import tensorflow as tf import tensorflow as tf
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import models from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import DataSize
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleAttackResult
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SingleSliceSpec
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.models import _sample_multidimensional_array
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.models import AttackerData
def _is_iterator(obj, obj_name): def _is_iterator(obj, obj_name):
@ -266,66 +259,6 @@ def _get_batch_accuracy_metrics(
return batch_correct_preds, batch_total_preds return batch_correct_preds, batch_total_preds
def create_seq2seq_attacker_data(
attack_input_data: Seq2SeqAttackInputData,
test_fraction: float = 0.25,
balance: bool = True,
privacy_report_metadata: PrivacyReportMetadata = PrivacyReportMetadata()
) -> AttackerData:
"""Prepares Seq2SeqAttackInputData to train ML attackers.
Uses logits and losses to generate ranks and performs a random train-test
split.
Also computes metadata (loss, accuracy) for the model under attack
and populates respective fields of PrivacyReportMetadata.
Args:
attack_input_data: Original Seq2SeqAttackInputData
test_fraction: Fraction of the dataset to include in the test split.
balance: Whether the training and test sets for the membership inference
attacker should have a balanced (roughly equal) number of samples from the
training and test sets used to develop the model under attack.
privacy_report_metadata: the metadata of the model under attack.
Returns:
AttackerData.
"""
attack_input_train, loss_train, accuracy_train = _get_attack_features_and_metadata(
attack_input_data.logits_train, attack_input_data.labels_train)
attack_input_test, loss_test, accuracy_test = _get_attack_features_and_metadata(
attack_input_data.logits_test, attack_input_data.labels_test)
if balance:
min_size = min(len(attack_input_train), len(attack_input_test))
attack_input_train = _sample_multidimensional_array(attack_input_train,
min_size)
attack_input_test = _sample_multidimensional_array(attack_input_test,
min_size)
features_all = np.concatenate((attack_input_train, attack_input_test))
ntrain, ntest = attack_input_train.shape[0], attack_input_test.shape[0]
# Reshape for classifying one-dimensional features
features_all = features_all.reshape(-1, 1)
labels_all = np.concatenate(((np.zeros(ntrain)), (np.ones(ntest))))
# Perform a train-test split
features_train, features_test, is_training_labels_train, is_training_labels_test = model_selection.train_test_split(
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
# Populate accuracy, loss fields in privacy report metadata
privacy_report_metadata.loss_train = loss_train
privacy_report_metadata.loss_test = loss_test
privacy_report_metadata.accuracy_train = accuracy_train
privacy_report_metadata.accuracy_test = accuracy_test
return AttackerData(features_train, is_training_labels_train, features_test,
is_training_labels_test,
DataSize(ntrain=ntrain, ntest=ntest))
def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData, def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
privacy_report_metadata: PrivacyReportMetadata = None, privacy_report_metadata: PrivacyReportMetadata = None,
balance_attacker_training: bool = True) -> AttackResults: balance_attacker_training: bool = True) -> AttackResults:
@ -343,39 +276,23 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
the attack result. the attack result.
""" """
attack_input.validate() attack_input.validate()
attack_input_train, loss_train, accuracy_train = _get_attack_features_and_metadata(
attack_input.logits_train, attack_input.labels_train)
attack_input_test, loss_test, accuracy_test = _get_attack_features_and_metadata(
attack_input.logits_test, attack_input.labels_test)
# The attacker uses the average rank (a single number) of a seq2seq dataset
# record to determine membership. So only Logistic Regression is supported,
# as it makes the most sense for single-number features.
attacker = models.LogisticRegressionAttacker()
# Create attacker data and populate fields of privacy_report_metadata
privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata() privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
prepared_attacker_data = create_seq2seq_attacker_data( privacy_report_metadata.loss_train = loss_train
attack_input_data=attack_input, privacy_report_metadata.loss_test = loss_test
balance=balance_attacker_training, privacy_report_metadata.accuracy_train = accuracy_train
privacy_report_metadata=privacy_report_metadata) privacy_report_metadata.accuracy_test = accuracy_test
attacker.train_model(prepared_attacker_data.features_train, # `attack_input_train` and `attack_input_test` contains the rank of the
prepared_attacker_data.is_training_labels_train) # ground-truth label in the logit, so smaller value means an example is
# more likely a training example.
# Run the attacker on (permuted) test examples. return mia.run_attacks(
predictions_test = attacker.predict(prepared_attacker_data.features_test) AttackInputData(
loss_train=attack_input_train, loss_test=attack_input_test),
# Generate ROC curves with predictions. attack_types=(AttackType.THRESHOLD_ATTACK,),
fpr, tpr, thresholds = metrics.roc_curve( privacy_report_metadata=privacy_report_metadata,
prepared_attacker_data.is_training_labels_test, predictions_test) balance_attacker_training=balance_attacker_training)
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
attack_results = [
SingleAttackResult(
slice_spec=SingleSliceSpec(),
attack_type=AttackType.LOGISTIC_REGRESSION,
roc_curve=roc_curve,
data_size=prepared_attacker_data.data_size)
]
return AttackResults(
single_attack_results=attack_results,
privacy_report_metadata=privacy_report_metadata)

View file

@ -16,8 +16,6 @@ from absl.testing import absltest
import numpy as np import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import create_seq2seq_attacker_data
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import run_seq2seq_attack from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import run_seq2seq_attack
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import Seq2SeqAttackInputData from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.seq2seq_mia import Seq2SeqAttackInputData
@ -91,154 +89,6 @@ class Seq2SeqAttackInputDataTest(absltest.TestCase):
test_size=0).validate) test_size=0).validate)
class Seq2SeqTrainedAttackerTest(absltest.TestCase):
def test_create_seq2seq_attacker_data_logits_and_labels(self):
attack_input = Seq2SeqAttackInputData(
logits_train=iter([
np.array([
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
],
dtype=object),
np.array(
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
dtype=object),
np.array([
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
],
dtype=object)
]),
logits_test=iter([
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
],
dtype=object),
np.array([
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
],
dtype=object)
]),
labels_train=iter([
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
],
dtype=object),
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
np.array([
np.array([0, 1], dtype=np.float32),
np.array([1, 2], dtype=np.float32)
],
dtype=object)
]),
labels_test=iter([
np.array([np.array([2, 1], dtype=np.float32)]),
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
],
dtype=object)
]),
vocab_size=3,
train_size=3,
test_size=2)
privacy_report_metadata = PrivacyReportMetadata()
attacker_data = create_seq2seq_attacker_data(
attack_input_data=attack_input,
test_fraction=0.25,
balance=False,
privacy_report_metadata=privacy_report_metadata)
self.assertLen(attacker_data.features_train, 3)
self.assertLen(attacker_data.features_test, 2)
for _, feature in enumerate(attacker_data.features_train):
self.assertLen(feature, 1) # each feature has one average rank
# Tests that fields of PrivacyReportMetadata are populated.
self.assertIsNotNone(privacy_report_metadata.loss_train)
self.assertIsNotNone(privacy_report_metadata.loss_test)
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
attack_input = Seq2SeqAttackInputData(
logits_train=iter([
np.array([
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
],
dtype=object),
np.array(
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
dtype=object),
np.array([
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
],
dtype=object)
]),
logits_test=iter([
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
],
dtype=object),
np.array([
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
],
dtype=object),
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
],
dtype=object)
]),
labels_train=iter([
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
],
dtype=object),
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
np.array([
np.array([0, 1], dtype=np.float32),
np.array([1, 2], dtype=np.float32)
],
dtype=object)
]),
labels_test=iter([
np.array([np.array([2, 1], dtype=np.float32)]),
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
],
dtype=object),
np.array([np.array([2, 1], dtype=np.float32)])
]),
vocab_size=3,
train_size=3,
test_size=3)
privacy_report_metadata = PrivacyReportMetadata()
attacker_data = create_seq2seq_attacker_data(
attack_input_data=attack_input,
test_fraction=0.33,
balance=True,
privacy_report_metadata=privacy_report_metadata)
self.assertLen(attacker_data.features_train, 4)
self.assertLen(attacker_data.features_test, 2)
for _, feature in enumerate(attacker_data.features_train):
self.assertLen(feature, 1) # each feature has one average rank
# Tests that fields of PrivacyReportMetadata are populated.
self.assertIsNotNone(privacy_report_metadata.loss_train)
self.assertIsNotNone(privacy_report_metadata.loss_test)
self.assertIsNotNone(privacy_report_metadata.accuracy_train)
self.assertIsNotNone(privacy_report_metadata.accuracy_test)
def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence, def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
vocab_size): vocab_size):
num_tokens_in_sequence = np.random.choice(max_tokens_in_sequence, num_tokens_in_sequence = np.random.choice(max_tokens_in_sequence,
@ -323,7 +173,7 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
max_tokens_in_sequence=5, max_tokens_in_sequence=5,
vocab_size=2)) vocab_size=2))
seq2seq_result = list(result.single_attack_results)[0] seq2seq_result = list(result.single_attack_results)[0]
self.assertEqual(seq2seq_result.attack_type, AttackType.LOGISTIC_REGRESSION) self.assertEqual(seq2seq_result.attack_type, AttackType.THRESHOLD_ATTACK)
def test_run_seq2seq_attack_calculates_correct_auc(self): def test_run_seq2seq_attack_calculates_correct_auc(self):
result = run_seq2seq_attack( result = run_seq2seq_attack(
@ -337,7 +187,7 @@ class RunSeq2SeqAttackTest(absltest.TestCase):
balance_attacker_training=False) balance_attacker_training=False)
seq2seq_result = list(result.single_attack_results)[0] seq2seq_result = list(result.single_attack_results)[0]
np.testing.assert_almost_equal( np.testing.assert_almost_equal(
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2) seq2seq_result.roc_curve.get_auc(), 0.59, decimal=2)
def test_run_seq2seq_attack_calculates_correct_metadata(self): def test_run_seq2seq_attack_calculates_correct_metadata(self):
attack_input = Seq2SeqAttackInputData( attack_input = Seq2SeqAttackInputData(