Merge pull request #144 from amad-person:refactor-seq2seq
PiperOrigin-RevId: 346307900
commit b208d9deec
9 changed files with 804 additions and 530 deletions
@@ -1142,8 +1142,8 @@
 }
 ],
 "source": [
-"from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia\n",
-"from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData\n",
+"from tensorflow_privacy.privacy.membership_inference_attack.seq2seq_mia import Seq2SeqAttackInputData, \\\n",
+" run_seq2seq_attack\n",
 "import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting\n",
 "\n",
 "attack_input = Seq2SeqAttackInputData(\n",
@@ -1157,13 +1157,16 @@
 ")\n",
 "\n",
 "# Run several attacks for different data slices\n",
-"attack_result = mia.run_seq2seq_attack(attack_input)\n",
+"attack_result = run_seq2seq_attack(attack_input)\n",
 "\n",
 "# Plot the ROC curve of the best classifier\n",
 "fig = plotting.plot_roc_curve(attack_result.get_result_with_max_auc().roc_curve)\n",
 "\n",
 "# Print a user-friendly summary of the attacks\n",
-"print(attack_result.summary())"
+"print(attack_result.summary())\n",
+"\n",
+"# Print metadata of the target model\n",
+"print(attack_result.privacy_report_metadata)"
 ]
 }
 ],
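Note: the updated codelab cell above boils down to the plain-Python sketch below. The random batch generator, sizes and vocab_size here are made-up placeholders (modelled on the test helper moved later in this diff), not part of this change; real inputs come from the seq2seq model under attack.

import numpy as np
import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting
from tensorflow_privacy.privacy.membership_inference_attack.seq2seq_mia import \
    Seq2SeqAttackInputData, run_seq2seq_attack

def _random_batch(vocab_size=3, max_sequences=3, max_tokens=5):
  # One batch = object arrays of per-sequence (num_tokens, vocab_size) logits
  # and (num_tokens,) labels, mirroring the shapes the attack expects.
  num_sequences = np.random.randint(1, max_sequences + 1)
  logits, labels = [], []
  for _ in range(num_sequences):
    num_tokens = np.random.randint(1, max_tokens + 1)
    logits.append(np.random.random((num_tokens, vocab_size)).astype(np.float32))
    labels.append(np.random.choice(vocab_size, num_tokens).astype(np.float32))
  return np.array(logits, dtype=object), np.array(labels, dtype=object)

train_batches = [_random_batch() for _ in range(10)]
test_batches = [_random_batch() for _ in range(5)]

attack_input = Seq2SeqAttackInputData(
    logits_train=iter([logits for logits, _ in train_batches]),
    logits_test=iter([logits for logits, _ in test_batches]),
    labels_train=iter([labels for _, labels in train_batches]),
    labels_test=iter([labels for _, labels in test_batches]),
    vocab_size=3,
    train_size=10,
    test_size=5)

attack_result = run_seq2seq_attack(attack_input)
fig = plotting.plot_roc_curve(attack_result.get_result_with_max_auc().roc_curve)
print(attack_result.summary())
print(attack_result.privacy_report_metadata)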
@@ -18,7 +18,7 @@ import enum
 import glob
 import os
 import pickle
-from typing import Any, Iterable, Union, Iterator
+from typing import Any, Iterable, Union

 from dataclasses import dataclass
 import numpy as np
@@ -378,91 +378,6 @@ def _append_array_shape(arr: np.array, arr_name: str, result):
   result.append('  %s with shape: %s,' % (arr_name, arr.shape))


-def _is_iterator(obj, obj_name):
-  """Checks whether obj is a generator."""
-  if obj is not None and not isinstance(obj, Iterator):
-    raise ValueError('%s should be a generator.' % obj_name)
-
-
-@dataclass
-class Seq2SeqAttackInputData:
-  """Input data for running an attack on seq2seq models.
-
-  This includes only the data, and not configuration.
-  """
-  logits_train: Iterator[np.ndarray] = None
-  logits_test: Iterator[np.ndarray] = None
-
-  # Contains ground-truth token indices for the target sequences.
-  labels_train: Iterator[np.ndarray] = None
-  labels_test: Iterator[np.ndarray] = None
-
-  # Size of the target sequence vocabulary.
-  vocab_size: int = None
-
-  # Train, test size = number of batches in training, test set.
-  # These values need to be supplied by the user as logits, labels
-  # are lazy loaded for seq2seq models.
-  train_size: int = 0
-  test_size: int = 0
-
-  def validate(self):
-    """Validates the inputs."""
-
-    if (self.logits_train is None) != (self.logits_test is None):
-      raise ValueError(
-          'logits_train and logits_test should both be either set or unset')
-
-    if (self.labels_train is None) != (self.labels_test is None):
-      raise ValueError(
-          'labels_train and labels_test should both be either set or unset')
-
-    if self.logits_train is None or self.labels_train is None:
-      raise ValueError(
-          'Labels, logits of training, test sets should all be set')
-
-    if (self.vocab_size is None or self.train_size is None or
-        self.test_size is None):
-      raise ValueError('vocab_size, train_size, test_size should all be set')
-
-    if self.vocab_size is not None and not int:
-      raise ValueError('vocab_size should be of integer type')
-
-    if self.train_size is not None and not int:
-      raise ValueError('train_size should be of integer type')
-
-    if self.test_size is not None and not int:
-      raise ValueError('test_size should be of integer type')
-
-    _is_iterator(self.logits_train, 'logits_train')
-    _is_iterator(self.logits_test, 'logits_test')
-    _is_iterator(self.labels_train, 'labels_train')
-    _is_iterator(self.labels_test, 'labels_test')
-
-  def __str__(self):
-    """Return the shapes of variables that are not None."""
-    result = ['AttackInputData(']
-
-    if self.vocab_size is not None and self.train_size is not None:
-      result.append(
-          'logits_train with shape (%d, num_sequences, num_tokens, %d)' %
-          (self.train_size, self.vocab_size))
-      result.append(
-          'labels_train with shape (%d, num_sequences, num_tokens, 1)' %
-          self.train_size)
-
-    if self.vocab_size is not None and self.test_size is not None:
-      result.append(
-          'logits_test with shape (%d, num_sequences, num_tokens, %d)' %
-          (self.test_size, self.vocab_size))
-      result.append(
-          'labels_test with shape (%d, num_sequences, num_tokens, 1)' %
-          self.test_size)
-
-    result.append(')')
-    return '\n'.join(result)
-
-
 @dataclass
 class RocCurve:
   """Represents ROC curve of a membership inference classifier."""
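For reference, a short sketch of how the validate() contract being moved out of data_structures.py behaves (all values made up; after this change the class is imported from seq2seq_mia instead): logits and labels must be supplied as generators, in train/test pairs, alongside vocab_size, train_size and test_size.

import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack.seq2seq_mia import \
    Seq2SeqAttackInputData

# Passes: generators for logits/labels plus the three size fields.
ok = Seq2SeqAttackInputData(
    logits_train=iter([np.array([])]), logits_test=iter([np.array([])]),
    labels_train=iter([np.array([])]), labels_test=iter([np.array([])]),
    vocab_size=3, train_size=1, test_size=1)
ok.validate()

# Raises ValueError('logits_train should be a generator.'): plain lists.
bad = Seq2SeqAttackInputData(
    logits_train=[], logits_test=[], labels_train=[], labels_test=[],
    vocab_size=3, train_size=1, test_size=1)
bad.validate()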
@@ -27,7 +27,6 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
-from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
@@ -153,75 +152,6 @@ class AttackInputDataTest(absltest.TestCase):
             probs_test=np.array([])).validate)


-class Seq2SeqAttackInputDataTest(absltest.TestCase):
-
-  def test_validator(self):
-    valid_logits_train = iter([np.array([]), np.array([])])
-    valid_logits_test = iter([np.array([]), np.array([])])
-    valid_labels_train = iter([np.array([]), np.array([])])
-    valid_labels_test = iter([np.array([]), np.array([])])
-
-    invalid_logits_train = []
-    invalid_logits_test = []
-    invalid_labels_train = []
-    invalid_labels_test = []
-
-    self.assertRaises(
-        ValueError,
-        Seq2SeqAttackInputData(logits_train=valid_logits_train).validate)
-    self.assertRaises(
-        ValueError,
-        Seq2SeqAttackInputData(labels_train=valid_labels_train).validate)
-    self.assertRaises(
-        ValueError,
-        Seq2SeqAttackInputData(logits_test=valid_logits_test).validate)
-    self.assertRaises(
-        ValueError,
-        Seq2SeqAttackInputData(labels_test=valid_labels_test).validate)
-    self.assertRaises(ValueError, Seq2SeqAttackInputData(vocab_size=0).validate)
-    self.assertRaises(ValueError, Seq2SeqAttackInputData(train_size=0).validate)
-    self.assertRaises(ValueError, Seq2SeqAttackInputData(test_size=0).validate)
-    self.assertRaises(ValueError, Seq2SeqAttackInputData().validate)
-
-    # Tests that both logits and labels must be set.
-    self.assertRaises(
-        ValueError,
-        Seq2SeqAttackInputData(
-            logits_train=valid_logits_train,
-            logits_test=valid_logits_test,
-            vocab_size=0,
-            train_size=0,
-            test_size=0).validate)
-    self.assertRaises(
-        ValueError,
-        Seq2SeqAttackInputData(
-            labels_train=valid_labels_train,
-            labels_test=valid_labels_test,
-            vocab_size=0,
-            train_size=0,
-            test_size=0).validate)
-
-    # Tests that vocab, train, test sizes must all be set.
-    self.assertRaises(
-        ValueError,
-        Seq2SeqAttackInputData(
-            logits_train=valid_logits_train,
-            logits_test=valid_logits_test,
-            labels_train=valid_labels_train,
-            labels_test=valid_labels_test).validate)
-
-    self.assertRaises(
-        ValueError,
-        Seq2SeqAttackInputData(
-            logits_train=invalid_logits_train,
-            logits_test=invalid_logits_test,
-            labels_train=invalid_labels_train,
-            labels_test=invalid_labels_test,
-            vocab_size=0,
-            train_size=0,
-            test_size=0).validate)
-
-
 class RocCurveTest(absltest.TestCase):

   def test_auc_random_classifier(self):
@@ -30,7 +30,6 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
     PrivacyReportMetadata
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
-from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
@@ -173,54 +172,6 @@ def run_attacks(attack_input: AttackInputData,
       privacy_report_metadata=privacy_report_metadata)


-def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
-                       unused_report_metadata: PrivacyReportMetadata = None,
-                       balance_attacker_training: bool = True) -> AttackResults:
-  """Runs membership inference attacks on a seq2seq model.
-
-  Args:
-    attack_input: input data for running an attack
-    unused_report_metadata: the metadata of the model under attack.
-    balance_attacker_training: Whether the training and test sets for the
-      membership inference attacker should have a balanced (roughly equal)
-      number of samples from the training and test sets used to develop the
-      model under attack.
-
-  Returns:
-    the attack result.
-  """
-  attack_input.validate()
-
-  # The attacker uses the average rank (a single number) of a seq2seq dataset
-  # record to determine membership. So only Logistic Regression is supported,
-  # as it makes the most sense for single-number features.
-  attacker = models.LogisticRegressionAttacker()
-
-  prepared_attacker_data = models.create_seq2seq_attacker_data(
-      attack_input, balance=balance_attacker_training)
-
-  attacker.train_model(prepared_attacker_data.features_train,
-                       prepared_attacker_data.is_training_labels_train)
-
-  # Run the attacker on (permuted) test examples.
-  predictions_test = attacker.predict(prepared_attacker_data.features_test)
-
-  # Generate ROC curves with predictions.
-  fpr, tpr, thresholds = metrics.roc_curve(
-      prepared_attacker_data.is_training_labels_test, predictions_test)
-
-  roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
-
-  attack_results = [
-      SingleAttackResult(
-          slice_spec=SingleSliceSpec(),
-          attack_type=AttackType.LOGISTIC_REGRESSION,
-          roc_curve=roc_curve)
-  ]
-
-  return AttackResults(single_attack_results=attack_results)
-
-
 def _compute_missing_privacy_report_metadata(
     metadata: PrivacyReportMetadata,
     attack_input: AttackInputData) -> PrivacyReportMetadata:
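The run_seq2seq_attack removed above packages the sklearn ROC output into the result data structures; a minimal standalone sketch with made-up attacker scores (the data-structure classes are the ones imported at the top of this file):

import numpy as np
from sklearn import metrics
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
    AttackResults, AttackType, RocCurve, SingleAttackResult, SingleSliceSpec

is_training_labels = np.array([0, 0, 1, 1])   # 0 = seen in training, 1 = held out
predictions = np.array([0.2, 0.4, 0.6, 0.9])  # hypothetical attacker membership scores

fpr, tpr, thresholds = metrics.roc_curve(is_training_labels, predictions)
roc = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
single_result = SingleAttackResult(
    slice_spec=SingleSliceSpec(),
    attack_type=AttackType.LOGISTIC_REGRESSION,
    roc_curve=roc)
print(AttackResults(single_attack_results=[single_result]).summary())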
@@ -16,10 +16,10 @@
 """Tests for tensorflow_privacy.privacy.membership_inference_attack.utils."""
 from absl.testing import absltest
 import numpy as np

 from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
-from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
@@ -35,68 +35,6 @@ def get_test_input(n_train, n_test):
       labels_test=np.array([i % 5 for i in range(n_test)]))


-def get_seq2seq_test_input(n_train,
-                           n_test,
-                           max_seq_in_batch,
-                           max_tokens_in_sequence,
-                           vocab_size,
-                           seed=None):
-  """Returns example inputs for attacks on seq2seq models."""
-  if seed is not None:
-    np.random.seed(seed=seed)
-
-  logits_train, labels_train = [], []
-  for _ in range(n_train):
-    num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1
-    batch_logits, batch_labels = _get_batch_logits_and_labels(
-        num_sequences, max_tokens_in_sequence, vocab_size)
-    logits_train.append(batch_logits)
-    labels_train.append(batch_labels)
-
-  logits_test, labels_test = [], []
-  for _ in range(n_test):
-    num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1
-    batch_logits, batch_labels = _get_batch_logits_and_labels(
-        num_sequences, max_tokens_in_sequence, vocab_size)
-    logits_test.append(batch_logits)
-    labels_test.append(batch_labels)
-
-  return Seq2SeqAttackInputData(
-      logits_train=iter(logits_train),
-      logits_test=iter(logits_test),
-      labels_train=iter(labels_train),
-      labels_test=iter(labels_test),
-      vocab_size=vocab_size,
-      train_size=n_train,
-      test_size=n_test)
-
-
-def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
-                                 vocab_size):
-  num_tokens_in_sequence = np.random.choice(max_tokens_in_sequence,
-                                            num_sequences) + 1
-  batch_logits, batch_labels = [], []
-  for num_tokens in num_tokens_in_sequence:
-    logits, labels = _get_sequence_logits_and_labels(num_tokens, vocab_size)
-    batch_logits.append(logits)
-    batch_labels.append(labels)
-  return np.array(
-      batch_logits, dtype=object), np.array(
-          batch_labels, dtype=object)
-
-
-def _get_sequence_logits_and_labels(num_tokens, vocab_size):
-  sequence_logits = []
-  for _ in range(num_tokens):
-    token_logits = np.random.random(vocab_size)
-    token_logits /= token_logits.sum()
-    sequence_logits.append(token_logits)
-  sequence_labels = np.random.choice(vocab_size, num_tokens)
-  return np.array(
-      sequence_logits, dtype=np.float32), np.array(
-          sequence_labels, dtype=np.float32)
-
-
 class RunAttacksTest(absltest.TestCase):

   def test_run_attacks_size(self):
@@ -160,42 +98,6 @@ class RunAttacksTest(absltest.TestCase):
     # If accuracy is already present, simply return it.
     self.assertIsNone(mia._get_accuracy(None, labels))

-  def test_run_seq2seq_attack_size(self):
-    result = mia.run_seq2seq_attack(
-        get_seq2seq_test_input(
-            n_train=10,
-            n_test=5,
-            max_seq_in_batch=3,
-            max_tokens_in_sequence=5,
-            vocab_size=2))
-
-    self.assertLen(result.single_attack_results, 1)
-
-  def test_run_seq2seq_attack_trained_sets_attack_type(self):
-    result = mia.run_seq2seq_attack(
-        get_seq2seq_test_input(
-            n_train=10,
-            n_test=5,
-            max_seq_in_batch=3,
-            max_tokens_in_sequence=5,
-            vocab_size=2))
-    seq2seq_result = list(result.single_attack_results)[0]
-    self.assertEqual(seq2seq_result.attack_type, AttackType.LOGISTIC_REGRESSION)
-
-  def test_run_seq2seq_attack_calculates_correct_auc(self):
-    result = mia.run_seq2seq_attack(
-        get_seq2seq_test_input(
-            n_train=20,
-            n_test=10,
-            max_seq_in_batch=3,
-            max_tokens_in_sequence=5,
-            vocab_size=3,
-            seed=12345),
-        balance_attacker_training=False)
-    seq2seq_result = list(result.single_attack_results)[0]
-    np.testing.assert_almost_equal(
-        seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
-

 if __name__ == '__main__':
   absltest.main()
@@ -15,11 +15,8 @@
 # Lint as: python3
 """Trained models for membership inference attacks."""

-from typing import Iterator, List
-
 from dataclasses import dataclass
 import numpy as np
-from scipy.stats import rankdata
 from sklearn import ensemble
 from sklearn import linear_model
 from sklearn import model_selection
@@ -27,7 +24,6 @@ from sklearn import neighbors
 from sklearn import neural_network

 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
-from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData


 @dataclass
@@ -114,98 +110,6 @@ def _column_stack(logits, loss):
   return np.column_stack((logits, loss))


-def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
-                                 test_fraction: float = 0.25,
-                                 balance: bool = True) -> AttackerData:
-  """Prepare Seq2SeqAttackInputData to train ML attackers.
-
-  Uses logits and losses to generate ranks and performs a random train-test
-  split.
-
-  Args:
-    attack_input_data: Original Seq2SeqAttackInputData
-    test_fraction: Fraction of the dataset to include in the test split.
-    balance: Whether the training and test sets for the membership inference
-      attacker should have a balanced (roughly equal) number of samples from the
-      training and test sets used to develop the model under attack.
-
-  Returns:
-    AttackerData.
-  """
-  attack_input_train = _get_average_ranks(attack_input_data.logits_train,
-                                          attack_input_data.labels_train)
-  attack_input_test = _get_average_ranks(attack_input_data.logits_test,
-                                         attack_input_data.labels_test)
-
-  if balance:
-    min_size = min(len(attack_input_train), len(attack_input_test))
-    attack_input_train = _sample_multidimensional_array(attack_input_train,
-                                                        min_size)
-    attack_input_test = _sample_multidimensional_array(attack_input_test,
-                                                       min_size)
-
-  features_all = np.concatenate((attack_input_train, attack_input_test))
-
-  # Reshape for classifying one-dimensional features
-  features_all = features_all.reshape(-1, 1)
-
-  labels_all = np.concatenate(
-      ((np.zeros(len(attack_input_train))), (np.ones(len(attack_input_test)))))
-
-  # Perform a train-test split
-  features_train, features_test, \
-  is_training_labels_train, is_training_labels_test = \
-    model_selection.train_test_split(
-        features_all, labels_all, test_size=test_fraction, stratify=labels_all)
-
-  return AttackerData(features_train, is_training_labels_train, features_test,
-                      is_training_labels_test)
-
-
-def _get_average_ranks(logits: Iterator[np.ndarray],
-                       labels: Iterator[np.ndarray]) -> np.ndarray:
-  """Returns the average rank of tokens in a batch of sequences.
-
-  Args:
-    logits: Logits returned by a seq2seq model, dim = (num_batches,
-      num_sequences, num_tokens, vocab_size).
-    labels: Target labels for the seq2seq model, dim = (num_batches,
-      num_sequences, num_tokens, 1).
-
-  Returns:
-    An array of average ranks, dim = (num_batches, 1).
-    Each average rank is calculated over ranks of tokens in sequences of a
-    particular batch.
-  """
-  ranks = []
-  for batch_logits, batch_labels in zip(logits, labels):
-    batch_ranks = []
-    for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
-      batch_ranks += _get_ranks_for_sequence(sequence_logits, sequence_labels)
-    ranks.append(np.mean(batch_ranks))
-
-  return np.array(ranks)
-
-
-def _get_ranks_for_sequence(logits: np.ndarray,
-                            labels: np.ndarray) -> List[float]:
-  """Returns ranks for a sequence.
-
-  Args:
-    logits: Logits of a single sequence, dim = (num_tokens, vocab_size).
-    labels: Target labels of a single sequence, dim = (num_tokens, 1).
-
-  Returns:
-    An array of ranks for tokens in the sequence, dim = (num_tokens, 1).
-  """
-  sequence_ranks = []
-  for logit, label in zip(logits, labels.astype(int)):
-    rank = rankdata(-logit, method='min')[label] - 1.0
-    sequence_ranks.append(rank)
-
-  return sequence_ranks
-
-
 class TrainedAttacker:
   """Base class for training attack models."""
   model = None
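The per-record feature this code builds (removed from models.py here and re-added in seq2seq_mia.py further down) is the average 0-based rank of each correct token under the model's logits; a lower average rank means the model fits the record better, which is the membership signal. A quick standalone check of the rank arithmetic with made-up numbers:

import numpy as np
from scipy.stats import rankdata

# One token position with vocab_size=3: the logits favour token 2.
logit = np.array([0.1, 0.1, 0.8], dtype=np.float32)

# rankdata(-logit) orders tokens from highest to lowest logit; 'min' breaks ties low.
print(rankdata(-logit, method='min'))           # [2. 2. 1.]
print(rankdata(-logit, method='min')[2] - 1.0)  # label 2 (top token)  -> rank 0.0
print(rankdata(-logit, method='min')[0] - 1.0)  # label 0 (tied lower) -> rank 1.0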
@@ -19,7 +19,6 @@ import numpy as np

 from tensorflow_privacy.privacy.membership_inference_attack import models
 from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
-from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData


 class TrainedAttackerTest(absltest.TestCase):
@@ -56,66 +55,6 @@ class TrainedAttackerTest(absltest.TestCase):
       expected = feature[:2] not in attack_input.logits_train
       self.assertEqual(attacker_data.is_training_labels_train[i], expected)

-  def test_create_seq2seq_attacker_data_logits_and_labels(self):
-    attack_input = Seq2SeqAttackInputData(
-        logits_train=iter([
-            np.array([
-                np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
-                np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
-            ],
-                     dtype=object),
-            np.array(
-                [np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
-                dtype=object),
-            np.array([
-                np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
-                np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
-            ],
-                     dtype=object)
-        ]),
-        logits_test=iter([
-            np.array([
-                np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
-            ],
-                     dtype=object),
-            np.array([
-                np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
-                np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
-            ],
-                     dtype=object)
-        ]),
-        labels_train=iter([
-            np.array([
-                np.array([2, 0], dtype=np.float32),
-                np.array([1], dtype=np.float32)
-            ],
-                     dtype=object),
-            np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
-            np.array([
-                np.array([0, 1], dtype=np.float32),
-                np.array([1, 2], dtype=np.float32)
-            ],
-                     dtype=object)
-        ]),
-        labels_test=iter([
-            np.array([np.array([2, 1], dtype=np.float32)]),
-            np.array([
-                np.array([2, 0], dtype=np.float32),
-                np.array([1], dtype=np.float32)
-            ],
-                     dtype=object)
-        ]),
-        vocab_size=3,
-        train_size=3,
-        test_size=2)
-    attacker_data = models.create_seq2seq_attacker_data(
-        attack_input, 0.25, balance=False)
-    self.assertLen(attacker_data.features_train, 3)
-    self.assertLen(attacker_data.features_test, 2)
-
-    for _, feature in enumerate(attacker_data.features_train):
-      self.assertLen(feature, 1)  # each feature has one average rank
-
   def test_balanced_create_attacker_data_loss_and_logits(self):
     attack_input = AttackInputData(
         logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
@@ -131,71 +70,6 @@ class TrainedAttackerTest(absltest.TestCase):
       expected = feature[:2] not in attack_input.logits_train
       self.assertEqual(attacker_data.is_training_labels_train[i], expected)

-  def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
-    attack_input = Seq2SeqAttackInputData(
-        logits_train=iter([
-            np.array([
-                np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
-                np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
-            ],
-                     dtype=object),
-            np.array(
-                [np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
-                dtype=object),
-            np.array([
-                np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
-                np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
-            ],
-                     dtype=object)
-        ]),
-        logits_test=iter([
-            np.array([
-                np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
-            ],
-                     dtype=object),
-            np.array([
-                np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
-                np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
-            ],
-                     dtype=object),
-            np.array([
-                np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
-            ],
-                     dtype=object)
-        ]),
-        labels_train=iter([
-            np.array([
-                np.array([2, 0], dtype=np.float32),
-                np.array([1], dtype=np.float32)
-            ],
-                     dtype=object),
-            np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
-            np.array([
-                np.array([0, 1], dtype=np.float32),
-                np.array([1, 2], dtype=np.float32)
-            ],
-                     dtype=object)
-        ]),
-        labels_test=iter([
-            np.array([np.array([2, 1], dtype=np.float32)]),
-            np.array([
-                np.array([2, 0], dtype=np.float32),
-                np.array([1], dtype=np.float32)
-            ],
-                     dtype=object),
-            np.array([np.array([2, 1], dtype=np.float32)])
-        ]),
-        vocab_size=3,
-        train_size=3,
-        test_size=3)
-    attacker_data = models.create_seq2seq_attacker_data(
-        attack_input, 0.33, balance=True)
-    self.assertLen(attacker_data.features_train, 4)
-    self.assertLen(attacker_data.features_test, 2)
-
-    for _, feature in enumerate(attacker_data.features_train):
-      self.assertLen(feature, 1)  # each feature has one average rank
-

 if __name__ == '__main__':
   absltest.main()
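Size bookkeeping behind the balanced test above: with train_size=3 and test_size=3 the balance step keeps min(3, 3) average ranks from each side, so features_all has 6 rows, and a stratified split with test_size=0.33 yields ceil(0.33 * 6) = 2 test rows and 4 train rows, matching the assertions. A minimal sketch of just that split (rank values are placeholders):

import numpy as np
from sklearn import model_selection

features_all = np.arange(6, dtype=float).reshape(-1, 1)  # 3 train + 3 test average ranks
labels_all = np.concatenate((np.zeros(3), np.ones(3)))   # 0 = train member, 1 = test

f_train, f_test, l_train, l_test = model_selection.train_test_split(
    features_all, labels_all, test_size=0.33, stratify=labels_all)
print(len(f_train), len(f_test))  # 4 2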
@@ -0,0 +1,370 @@
+# Copyright 2020, The TensorFlow Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Code for membership inference attacks on seq2seq models.
+
+Contains seq2seq specific logic for attack data structures, attack data
+generation,
+and the logistic regression membership inference attack.
+"""
+from typing import Iterator, List
+
+from dataclasses import dataclass
+import numpy as np
+from scipy.stats import rankdata
+from sklearn import metrics
+from sklearn import model_selection
+import tensorflow as tf
+
+from tensorflow_privacy.privacy.membership_inference_attack import models
+from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
+from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
+from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
+from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
+from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
+from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
+from tensorflow_privacy.privacy.membership_inference_attack.models import _sample_multidimensional_array
+from tensorflow_privacy.privacy.membership_inference_attack.models import AttackerData
+
+
+def _is_iterator(obj, obj_name):
+  """Checks whether obj is a generator."""
+  if obj is not None and not isinstance(obj, Iterator):
+    raise ValueError('%s should be a generator.' % obj_name)
+
+
+@dataclass
+class Seq2SeqAttackInputData:
+  """Input data for running an attack on seq2seq models.
+
+  This includes only the data, and not configuration.
+  """
+  logits_train: Iterator[np.ndarray] = None
+  logits_test: Iterator[np.ndarray] = None
+
+  # Contains ground-truth token indices for the target sequences.
+  labels_train: Iterator[np.ndarray] = None
+  labels_test: Iterator[np.ndarray] = None
+
+  # Size of the target sequence vocabulary.
+  vocab_size: int = None
+
+  # Train, test size = number of batches in training, test set.
+  # These values need to be supplied by the user as logits, labels
+  # are lazy loaded for seq2seq models.
+  train_size: int = 0
+  test_size: int = 0
+
+  def validate(self):
+    """Validates the inputs."""
+
+    if (self.logits_train is None) != (self.logits_test is None):
+      raise ValueError(
+          'logits_train and logits_test should both be either set or unset')
+
+    if (self.labels_train is None) != (self.labels_test is None):
+      raise ValueError(
+          'labels_train and labels_test should both be either set or unset')
+
+    if self.logits_train is None or self.labels_train is None:
+      raise ValueError(
+          'Labels, logits of training, test sets should all be set')
+
+    if (self.vocab_size is None or self.train_size is None or
+        self.test_size is None):
+      raise ValueError('vocab_size, train_size, test_size should all be set')
+
+    if self.vocab_size is not None and not int:
+      raise ValueError('vocab_size should be of integer type')
+
+    if self.train_size is not None and not int:
+      raise ValueError('train_size should be of integer type')
+
+    if self.test_size is not None and not int:
+      raise ValueError('test_size should be of integer type')
+
+    _is_iterator(self.logits_train, 'logits_train')
+    _is_iterator(self.logits_test, 'logits_test')
+    _is_iterator(self.labels_train, 'labels_train')
+    _is_iterator(self.labels_test, 'labels_test')
+
+  def __str__(self):
+    """Returns the shapes of variables that are not None."""
+    result = ['AttackInputData(']
+
+    if self.vocab_size is not None and self.train_size is not None:
+      result.append(
+          'logits_train with shape (%d, num_sequences, num_tokens, %d)' %
+          (self.train_size, self.vocab_size))
+      result.append(
+          'labels_train with shape (%d, num_sequences, num_tokens, 1)' %
+          self.train_size)
+
+    if self.vocab_size is not None and self.test_size is not None:
+      result.append(
+          'logits_test with shape (%d, num_sequences, num_tokens, %d)' %
+          (self.test_size, self.vocab_size))
+      result.append(
+          'labels_test with shape (%d, num_sequences, num_tokens, 1)' %
+          self.test_size)
+
+    result.append(')')
+    return '\n'.join(result)
+
+
+def _get_attack_features_and_metadata(
+    logits: Iterator[np.ndarray],
+    labels: Iterator[np.ndarray]) -> (np.ndarray, float, float):
+  """Returns the average rank of tokens per batch of sequences and the loss.
+
+  Args:
+    logits: Logits returned by a seq2seq model, dim = (num_batches,
+      num_sequences, num_tokens, vocab_size).
+    labels: Target labels for the seq2seq model, dim = (num_batches,
+      num_sequences, num_tokens, 1).
+
+  Returns:
+    1. An array of average ranks, dim = (num_batches, 1).
+    Each average rank is calculated over ranks of tokens in sequences of a
+    particular batch.
+    2. Loss computed over all logits and labels.
+    3. Accuracy computed over all logits and labels.
+  """
+  ranks = []
+  loss = 0.0
+  dataset_length = 0.0
+  correct_preds = 0
+  total_preds = 0
+  for batch_logits, batch_labels in zip(logits, labels):
+    # Compute average rank for the current batch.
+    batch_ranks = _get_batch_ranks(batch_logits, batch_labels)
+    ranks.append(np.mean(batch_ranks))
+
+    # Update overall loss metrics with metrics of the current batch.
+    batch_loss, batch_length = _get_batch_loss_metrics(batch_logits,
+                                                       batch_labels)
+    loss += batch_loss
+    dataset_length += batch_length
+
+    # Update overall accuracy metrics with metrics of the current batch.
+    batch_correct_preds, batch_total_preds = _get_batch_accuracy_metrics(
+        batch_logits, batch_labels)
+    correct_preds += batch_correct_preds
+    total_preds += batch_total_preds
+
+  # Compute loss and accuracy for the dataset.
+  loss = loss / dataset_length
+  accuracy = correct_preds / total_preds
+
+  return np.array(ranks), loss, accuracy
+
+
+def _get_batch_ranks(batch_logits: np.ndarray,
+                     batch_labels: np.ndarray) -> np.ndarray:
+  """Returns the ranks of tokens in a batch of sequences.
+
+  Args:
+    batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
+      num_tokens, vocab_size).
+    batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
+      num_tokens, 1).
+
+  Returns:
+    An array of ranks of tokens in a batch of sequences, dim = (num_sequences,
+    num_tokens, 1)
+  """
+  batch_ranks = []
+  for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
+    batch_ranks += _get_ranks_for_sequence(sequence_logits, sequence_labels)
+
+  return np.array(batch_ranks)
+
+
+def _get_ranks_for_sequence(logits: np.ndarray,
+                            labels: np.ndarray) -> List[float]:
+  """Returns ranks for a sequence.
+
+  Args:
+    logits: Logits of a single sequence, dim = (num_tokens, vocab_size).
+    labels: Target labels of a single sequence, dim = (num_tokens, 1).
+
+  Returns:
+    An array of ranks for tokens in the sequence, dim = (num_tokens, 1).
+  """
+  sequence_ranks = []
+  for logit, label in zip(logits, labels.astype(int)):
+    rank = rankdata(-logit, method='min')[label] - 1.0
+    sequence_ranks.append(rank)
+
+  return sequence_ranks
+
+
+def _get_batch_loss_metrics(batch_logits: np.ndarray,
+                            batch_labels: np.ndarray) -> (float, int):
+  """Returns the loss, number of sequences for a batch.
+
+  Args:
+    batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
+      num_tokens, vocab_size).
+    batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
+      num_tokens, 1).
+  """
+  batch_loss = 0.0
+  batch_length = len(batch_logits)
+  for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
+    sequence_loss = tf.losses.sparse_categorical_crossentropy(
+        tf.keras.backend.constant(sequence_labels),
+        tf.keras.backend.constant(sequence_logits),
+        from_logits=True)
+    batch_loss += sequence_loss.numpy().sum()
+
+  return batch_loss / batch_length, batch_length
+
+
+def _get_batch_accuracy_metrics(batch_logits: np.ndarray,
+                                batch_labels: np.ndarray) -> (float, float):
+  """Returns the number of correct predictions, total number of predictions for a batch.
+
+  Args:
+    batch_logits: Logits returned by a seq2seq model, dim = (num_sequences,
+      num_tokens, vocab_size).
+    batch_labels: Target labels for the seq2seq model, dim = (num_sequences,
+      num_tokens, 1).
+  """
+  batch_correct_preds = 0.0
+  batch_total_preds = 0.0
+  for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
+    preds = tf.metrics.sparse_categorical_accuracy(
+        tf.keras.backend.constant(sequence_labels),
+        tf.keras.backend.constant(sequence_logits))
+    batch_correct_preds += preds.numpy().sum()
+    batch_total_preds += len(sequence_labels)
+
+  return batch_correct_preds, batch_total_preds
+
+
+def create_seq2seq_attacker_data(
+    attack_input_data: Seq2SeqAttackInputData,
+    test_fraction: float = 0.25,
+    balance: bool = True,
+    privacy_report_metadata: PrivacyReportMetadata = PrivacyReportMetadata()
+) -> AttackerData:
+  """Prepares Seq2SeqAttackInputData to train ML attackers.
+
+  Uses logits and losses to generate ranks and performs a random train-test
+  split.
+
+  Also computes metadata (loss, accuracy) for the model under attack
+  and populates respective fields of PrivacyReportMetadata.
+
+  Args:
+    attack_input_data: Original Seq2SeqAttackInputData
+    test_fraction: Fraction of the dataset to include in the test split.
+    balance: Whether the training and test sets for the membership inference
+      attacker should have a balanced (roughly equal) number of samples from the
+      training and test sets used to develop the model under attack.
+    privacy_report_metadata: the metadata of the model under attack.
+
+  Returns:
+    AttackerData.
+  """
+  attack_input_train, loss_train, accuracy_train = _get_attack_features_and_metadata(
+      attack_input_data.logits_train, attack_input_data.labels_train)
+  attack_input_test, loss_test, accuracy_test = _get_attack_features_and_metadata(
+      attack_input_data.logits_test, attack_input_data.labels_test)
+
+  if balance:
+    min_size = min(len(attack_input_train), len(attack_input_test))
+    attack_input_train = _sample_multidimensional_array(attack_input_train,
+                                                        min_size)
+    attack_input_test = _sample_multidimensional_array(attack_input_test,
+                                                       min_size)
+
+  features_all = np.concatenate((attack_input_train, attack_input_test))
+
+  # Reshape for classifying one-dimensional features
+  features_all = features_all.reshape(-1, 1)
+
+  labels_all = np.concatenate(
+      ((np.zeros(len(attack_input_train))), (np.ones(len(attack_input_test)))))
+
+  # Perform a train-test split
+  features_train, features_test, \
+  is_training_labels_train, is_training_labels_test = \
+    model_selection.train_test_split(
+        features_all, labels_all, test_size=test_fraction, stratify=labels_all)
+
+  # Populate accuracy, loss fields in privacy report metadata
+  privacy_report_metadata.loss_train = loss_train
+  privacy_report_metadata.loss_test = loss_test
+  privacy_report_metadata.accuracy_train = accuracy_train
+  privacy_report_metadata.accuracy_test = accuracy_test
+
+  return AttackerData(features_train, is_training_labels_train, features_test,
+                      is_training_labels_test)
+
+
+def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
+                       privacy_report_metadata: PrivacyReportMetadata = None,
+                       balance_attacker_training: bool = True) -> AttackResults:
+  """Runs membership inference attacks on a seq2seq model.
+
+  Args:
+    attack_input: input data for running an attack
+    privacy_report_metadata: the metadata of the model under attack.
+    balance_attacker_training: Whether the training and test sets for the
+      membership inference attacker should have a balanced (roughly equal)
+      number of samples from the training and test sets used to develop the
+      model under attack.
+
+  Returns:
+    the attack result.
+  """
+  attack_input.validate()
+
+  # The attacker uses the average rank (a single number) of a seq2seq dataset
+  # record to determine membership. So only Logistic Regression is supported,
+  # as it makes the most sense for single-number features.
+  attacker = models.LogisticRegressionAttacker()
+
+  # Create attacker data and populate fields of privacy_report_metadata
+  privacy_report_metadata = privacy_report_metadata or PrivacyReportMetadata()
+  prepared_attacker_data = create_seq2seq_attacker_data(
+      attack_input_data=attack_input,
+      balance=balance_attacker_training,
+      privacy_report_metadata=privacy_report_metadata)
+
+  attacker.train_model(prepared_attacker_data.features_train,
+                       prepared_attacker_data.is_training_labels_train)
+
+  # Run the attacker on (permuted) test examples.
+  predictions_test = attacker.predict(prepared_attacker_data.features_test)
+
+  # Generate ROC curves with predictions.
+  fpr, tpr, thresholds = metrics.roc_curve(
+      prepared_attacker_data.is_training_labels_test, predictions_test)
+
+  roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
+
+  attack_results = [
+      SingleAttackResult(
+          slice_spec=SingleSliceSpec(),
+          attack_type=AttackType.LOGISTIC_REGRESSION,
+          roc_curve=roc_curve)
+  ]
+
+  return AttackResults(
+      single_attack_results=attack_results,
+      privacy_report_metadata=privacy_report_metadata)
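The loss and accuracy that create_seq2seq_attacker_data now writes into PrivacyReportMetadata come from the Keras helpers used above; a standalone check of the same calls on one made-up sequence (assumes TF 2.x eager execution):

import numpy as np
import tensorflow as tf

# One sequence with two tokens and vocab_size=3 (values are illustrative).
sequence_logits = np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0.0]], dtype=np.float32)
sequence_labels = np.array([2, 0], dtype=np.float32)

loss = tf.losses.sparse_categorical_crossentropy(
    tf.keras.backend.constant(sequence_labels),
    tf.keras.backend.constant(sequence_logits),
    from_logits=True)
accuracy = tf.metrics.sparse_categorical_accuracy(
    tf.keras.backend.constant(sequence_labels),
    tf.keras.backend.constant(sequence_logits))

print(loss.numpy().sum())                            # summed per-token loss for the sequence
print(accuracy.numpy().sum(), len(sequence_labels))  # 2.0 correct out of 2 tokens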
@ -0,0 +1,425 @@
|
||||||
|
# Copyright 2020, The TensorFlow Authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Lint as: python3
|
||||||
|
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.seq2seq_mia."""
|
||||||
|
from absl.testing import absltest
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.seq2seq_mia import create_seq2seq_attacker_data
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.seq2seq_mia import run_seq2seq_attack
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.seq2seq_mia import Seq2SeqAttackInputData
|
||||||
|
|
||||||
|
|
||||||
|
class Seq2SeqAttackInputDataTest(absltest.TestCase):
|
||||||
|
|
||||||
|
def test_validator(self):
|
||||||
|
valid_logits_train = iter([np.array([]), np.array([])])
|
||||||
|
valid_logits_test = iter([np.array([]), np.array([])])
|
||||||
|
valid_labels_train = iter([np.array([]), np.array([])])
|
||||||
|
valid_labels_test = iter([np.array([]), np.array([])])
|
||||||
|
|
||||||
|
invalid_logits_train = []
|
||||||
|
invalid_logits_test = []
|
||||||
|
invalid_labels_train = []
|
||||||
|
invalid_labels_test = []
|
||||||
|
|
||||||
|
self.assertRaises(
|
||||||
|
ValueError,
|
||||||
|
Seq2SeqAttackInputData(logits_train=valid_logits_train).validate)
|
||||||
|
self.assertRaises(
|
||||||
|
ValueError,
|
||||||
|
Seq2SeqAttackInputData(labels_train=valid_labels_train).validate)
|
||||||
|
self.assertRaises(
|
||||||
|
ValueError,
|
||||||
|
Seq2SeqAttackInputData(logits_test=valid_logits_test).validate)
|
||||||
|
self.assertRaises(
|
||||||
|
ValueError,
|
||||||
|
Seq2SeqAttackInputData(labels_test=valid_labels_test).validate)
|
||||||
|
self.assertRaises(ValueError, Seq2SeqAttackInputData(vocab_size=0).validate)
|
||||||
|
self.assertRaises(ValueError, Seq2SeqAttackInputData(train_size=0).validate)
|
||||||
|
self.assertRaises(ValueError, Seq2SeqAttackInputData(test_size=0).validate)
|
||||||
|
self.assertRaises(ValueError, Seq2SeqAttackInputData().validate)
|
||||||
|
|
||||||
|
# Tests that both logits and labels must be set.
|
||||||
|
self.assertRaises(
|
||||||
|
ValueError,
|
||||||
|
Seq2SeqAttackInputData(
|
||||||
|
logits_train=valid_logits_train,
|
||||||
|
logits_test=valid_logits_test,
|
||||||
|
vocab_size=0,
|
||||||
|
train_size=0,
|
||||||
|
test_size=0).validate)
|
||||||
|
self.assertRaises(
|
||||||
|
ValueError,
|
||||||
|
Seq2SeqAttackInputData(
|
||||||
|
labels_train=valid_labels_train,
|
||||||
|
labels_test=valid_labels_test,
|
||||||
|
vocab_size=0,
|
||||||
|
train_size=0,
|
||||||
|
test_size=0).validate)
|
||||||
|
|
||||||
|
# Tests that vocab, train, test sizes must all be set.
|
||||||
|
self.assertRaises(
|
||||||
|
ValueError,
|
||||||
|
Seq2SeqAttackInputData(
|
||||||
|
logits_train=valid_logits_train,
|
||||||
|
logits_test=valid_logits_test,
|
||||||
|
labels_train=valid_labels_train,
|
||||||
|
labels_test=valid_labels_test).validate)
|
||||||
|
|
||||||
|
self.assertRaises(
|
||||||
|
ValueError,
|
||||||
|
Seq2SeqAttackInputData(
|
||||||
|
logits_train=invalid_logits_train,
|
||||||
|
logits_test=invalid_logits_test,
|
||||||
|
labels_train=invalid_labels_train,
|
||||||
|
labels_test=invalid_labels_test,
|
||||||
|
vocab_size=0,
|
||||||
|
train_size=0,
|
||||||
|
test_size=0).validate)
|
||||||
|
|
||||||
|
|
||||||
|
class Seq2SeqTrainedAttackerTest(absltest.TestCase):

  def test_create_seq2seq_attacker_data_logits_and_labels(self):
    attack_input = Seq2SeqAttackInputData(
        logits_train=iter([
            np.array([
                np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
                np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
            ], dtype=object),
            np.array(
                [np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
                dtype=object),
            np.array([
                np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
                np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
            ], dtype=object)
        ]),
        logits_test=iter([
            np.array([
                np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
            ], dtype=object),
            np.array([
                np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
                np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
            ], dtype=object)
        ]),
        labels_train=iter([
            np.array([
                np.array([2, 0], dtype=np.float32),
                np.array([1], dtype=np.float32)
            ], dtype=object),
            np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
            np.array([
                np.array([0, 1], dtype=np.float32),
                np.array([1, 2], dtype=np.float32)
            ], dtype=object)
        ]),
        labels_test=iter([
            np.array([np.array([2, 1], dtype=np.float32)]),
            np.array([
                np.array([2, 0], dtype=np.float32),
                np.array([1], dtype=np.float32)
            ], dtype=object)
        ]),
        vocab_size=3,
        train_size=3,
        test_size=2)
    privacy_report_metadata = PrivacyReportMetadata()
    attacker_data = create_seq2seq_attacker_data(
        attack_input_data=attack_input,
        test_fraction=0.25,
        balance=False,
        privacy_report_metadata=privacy_report_metadata)
    self.assertLen(attacker_data.features_train, 3)
    self.assertLen(attacker_data.features_test, 2)

    for _, feature in enumerate(attacker_data.features_train):
      self.assertLen(feature, 1)  # each feature has one average rank

    # Tests that fields of PrivacyReportMetadata are populated.
    self.assertIsNotNone(privacy_report_metadata.loss_train)
    self.assertIsNotNone(privacy_report_metadata.loss_test)
    self.assertIsNotNone(privacy_report_metadata.accuracy_train)
    self.assertIsNotNone(privacy_report_metadata.accuracy_test)

  def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
    attack_input = Seq2SeqAttackInputData(
        logits_train=iter([
            np.array([
                np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
                np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
            ], dtype=object),
            np.array(
                [np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
                dtype=object),
            np.array([
                np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
                np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
            ], dtype=object)
        ]),
        logits_test=iter([
            np.array([
                np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
            ], dtype=object),
            np.array([
                np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
                np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
            ], dtype=object),
            np.array([
                np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
            ], dtype=object)
        ]),
        labels_train=iter([
            np.array([
                np.array([2, 0], dtype=np.float32),
                np.array([1], dtype=np.float32)
            ], dtype=object),
            np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
            np.array([
                np.array([0, 1], dtype=np.float32),
                np.array([1, 2], dtype=np.float32)
            ], dtype=object)
        ]),
        labels_test=iter([
            np.array([np.array([2, 1], dtype=np.float32)]),
            np.array([
                np.array([2, 0], dtype=np.float32),
                np.array([1], dtype=np.float32)
            ], dtype=object),
            np.array([np.array([2, 1], dtype=np.float32)])
        ]),
        vocab_size=3,
        train_size=3,
        test_size=3)
    privacy_report_metadata = PrivacyReportMetadata()
    attacker_data = create_seq2seq_attacker_data(
        attack_input_data=attack_input,
        test_fraction=0.33,
        balance=True,
        privacy_report_metadata=privacy_report_metadata)
    self.assertLen(attacker_data.features_train, 4)
    self.assertLen(attacker_data.features_test, 2)

    for _, feature in enumerate(attacker_data.features_train):
      self.assertLen(feature, 1)  # each feature has one average rank

    # Tests that fields of PrivacyReportMetadata are populated.
    self.assertIsNotNone(privacy_report_metadata.loss_train)
    self.assertIsNotNone(privacy_report_metadata.loss_test)
    self.assertIsNotNone(privacy_report_metadata.accuracy_train)
    self.assertIsNotNone(privacy_report_metadata.accuracy_test)

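# Helpers for building random attack inputs. Each batch is a ragged np.array
# (dtype=object) of per-sequence logit matrices with shape
# (num_tokens, vocab_size); per-token logits are normalized to sum to 1, and
# labels are token indices drawn from the vocabulary.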
def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
                                 vocab_size):
  num_tokens_in_sequence = np.random.choice(max_tokens_in_sequence,
                                            num_sequences) + 1
  batch_logits, batch_labels = [], []
  for num_tokens in num_tokens_in_sequence:
    logits, labels = _get_sequence_logits_and_labels(num_tokens, vocab_size)
    batch_logits.append(logits)
    batch_labels.append(labels)
  return np.array(batch_logits, dtype=object), np.array(
      batch_labels, dtype=object)

def _get_sequence_logits_and_labels(num_tokens, vocab_size):
  sequence_logits = []
  for _ in range(num_tokens):
    token_logits = np.random.random(vocab_size)
    token_logits /= token_logits.sum()
    sequence_logits.append(token_logits)
  sequence_labels = np.random.choice(vocab_size, num_tokens)
  return np.array(sequence_logits, dtype=np.float32), np.array(
      sequence_labels, dtype=np.float32)

def get_seq2seq_test_input(n_train,
                           n_test,
                           max_seq_in_batch,
                           max_tokens_in_sequence,
                           vocab_size,
                           seed=None):
  """Returns example inputs for attacks on seq2seq models."""
  if seed is not None:
    np.random.seed(seed=seed)

  logits_train, labels_train = [], []
  for _ in range(n_train):
    num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1
    batch_logits, batch_labels = _get_batch_logits_and_labels(
        num_sequences, max_tokens_in_sequence, vocab_size)
    logits_train.append(batch_logits)
    labels_train.append(batch_labels)

  logits_test, labels_test = [], []
  for _ in range(n_test):
    num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1
    batch_logits, batch_labels = _get_batch_logits_and_labels(
        num_sequences, max_tokens_in_sequence, vocab_size)
    logits_test.append(batch_logits)
    labels_test.append(batch_labels)

  return Seq2SeqAttackInputData(
      logits_train=iter(logits_train),
      logits_test=iter(logits_test),
      labels_train=iter(labels_train),
      labels_test=iter(labels_test),
      vocab_size=vocab_size,
      train_size=n_train,
      test_size=n_test)

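# Sketch (not part of the test suite): get_seq2seq_test_input can also drive a
# quick end-to-end smoke run of the attack, along the lines of
#
#   attack_input = get_seq2seq_test_input(
#       n_train=10, n_test=5, max_seq_in_batch=3,
#       max_tokens_in_sequence=5, vocab_size=2, seed=0)
#   result = run_seq2seq_attack(attack_input)
#   auc = list(result.single_attack_results)[0].roc_curve.get_auc()
#
# which is exactly the pattern exercised by RunSeq2SeqAttackTest below.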
class RunSeq2SeqAttackTest(absltest.TestCase):

  def test_run_seq2seq_attack_size(self):
    result = run_seq2seq_attack(
        get_seq2seq_test_input(
            n_train=10,
            n_test=5,
            max_seq_in_batch=3,
            max_tokens_in_sequence=5,
            vocab_size=2))

    self.assertLen(result.single_attack_results, 1)

  def test_run_seq2seq_attack_trained_sets_attack_type(self):
    result = run_seq2seq_attack(
        get_seq2seq_test_input(
            n_train=10,
            n_test=5,
            max_seq_in_batch=3,
            max_tokens_in_sequence=5,
            vocab_size=2))
    seq2seq_result = list(result.single_attack_results)[0]
    self.assertEqual(seq2seq_result.attack_type, AttackType.LOGISTIC_REGRESSION)

  def test_run_seq2seq_attack_calculates_correct_auc(self):
    result = run_seq2seq_attack(
        get_seq2seq_test_input(
            n_train=20,
            n_test=10,
            max_seq_in_batch=3,
            max_tokens_in_sequence=5,
            vocab_size=3,
            seed=12345),
        balance_attacker_training=False)
    seq2seq_result = list(result.single_attack_results)[0]
    np.testing.assert_almost_equal(
        seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)

  def test_run_seq2seq_attack_calculates_correct_metadata(self):
    attack_input = Seq2SeqAttackInputData(
        logits_train=iter([
            np.array([
                np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
                np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
            ], dtype=object),
            np.array(
                [np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
                dtype=object),
            np.array([
                np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
                np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
            ], dtype=object),
            np.array([
                np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
                np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
            ], dtype=object)
        ]),
        logits_test=iter([
            np.array([
                np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
            ], dtype=object),
            np.array([
                np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
                np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
            ], dtype=object),
            np.array([
                np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
            ], dtype=object),
            np.array([
                np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
            ], dtype=object)
        ]),
        labels_train=iter([
            np.array([
                np.array([2, 0], dtype=np.float32),
                np.array([1], dtype=np.float32)
            ], dtype=object),
            np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
            np.array([
                np.array([0, 1], dtype=np.float32),
                np.array([1, 2], dtype=np.float32)
            ], dtype=object),
            np.array([
                np.array([0, 0], dtype=np.float32),
                np.array([0, 1], dtype=np.float32)
            ], dtype=object)
        ]),
        labels_test=iter([
            np.array([np.array([2, 1], dtype=np.float32)]),
            np.array([
                np.array([2, 0], dtype=np.float32),
                np.array([1], dtype=np.float32)
            ], dtype=object),
            np.array([np.array([2, 1], dtype=np.float32)]),
            np.array([np.array([2, 1], dtype=np.float32)]),
        ]),
        vocab_size=3,
        train_size=4,
        test_size=4)
    result = run_seq2seq_attack(attack_input, balance_attacker_training=False)
    metadata = result.privacy_report_metadata
    np.testing.assert_almost_equal(metadata.loss_train, 0.91, decimal=2)
    np.testing.assert_almost_equal(metadata.loss_test, 1.58, decimal=2)
    np.testing.assert_almost_equal(metadata.accuracy_train, 0.77, decimal=2)
    np.testing.assert_almost_equal(metadata.accuracy_test, 0.67, decimal=2)

if __name__ == '__main__':
  absltest.main()