forked from 626_privacy/tensorflow_privacy
Merge pull request #137 from amad-person:add_seq2seq_mia_attacks
PiperOrigin-RevId: 343047622
This commit is contained in:
commit
35a8096173
8 changed files with 1738 additions and 2 deletions
|
@ -0,0 +1,21 @@
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2019 Congzheng Song
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
File diff suppressed because one or more lines are too long
|
@ -18,7 +18,7 @@ import enum
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from typing import Any, Iterable, Union
|
from typing import Any, Iterable, Union, Iterator
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -378,6 +378,91 @@ def _append_array_shape(arr: np.array, arr_name: str, result):
|
||||||
result.append(' %s with shape: %s,' % (arr_name, arr.shape))
|
result.append(' %s with shape: %s,' % (arr_name, arr.shape))
|
||||||
|
|
||||||
|
|
||||||
|
def _is_iterator(obj, obj_name):
|
||||||
|
"""Checks whether obj is a generator."""
|
||||||
|
if obj is not None and not isinstance(obj, Iterator):
|
||||||
|
raise ValueError('%s should be a generator.' % obj_name)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Seq2SeqAttackInputData:
  """Input data for running an attack on seq2seq models.

  This includes only the data, and not configuration.
  """
  # Iterators over the logits of the training / test set.
  logits_train: Iterator[np.ndarray] = None
  logits_test: Iterator[np.ndarray] = None

  # Contains ground-truth token indices for the target sequences.
  labels_train: Iterator[np.ndarray] = None
  labels_test: Iterator[np.ndarray] = None

  # Size of the target sequence vocabulary.
  vocab_size: int = None

  # Train, test size = number of batches in training, test set.
  # These values need to be supplied by the user as logits, labels
  # are lazy loaded for seq2seq models.
  train_size: int = 0
  test_size: int = 0

  def validate(self):
    """Validates the inputs.

    Raises:
      ValueError: If only one of the train/test logits (or labels) is set, if
        logits or labels are missing, if any of vocab_size, train_size,
        test_size is None or not an integer, or if a logits/labels field is
        not an iterator.
    """

    if (self.logits_train is None) != (self.logits_test is None):
      raise ValueError(
          'logits_train and logits_test should both be either set or unset')

    if (self.labels_train is None) != (self.labels_test is None):
      raise ValueError(
          'labels_train and labels_test should both be either set or unset')

    if self.logits_train is None or self.labels_train is None:
      raise ValueError(
          'Labels, logits of training, test sets should all be set')

    if (self.vocab_size is None or self.train_size is None or
        self.test_size is None):
      raise ValueError('vocab_size, train_size, test_size should all be set')

    # The previous check was `not int`, which negates the type object itself
    # and is therefore always False. Test the actual values instead. NumPy
    # integer scalars are accepted alongside Python ints.
    for size_name in ('vocab_size', 'train_size', 'test_size'):
      if not isinstance(getattr(self, size_name), (int, np.integer)):
        raise ValueError('%s should be of integer type' % size_name)

    _is_iterator(self.logits_train, 'logits_train')
    _is_iterator(self.logits_test, 'logits_test')
    _is_iterator(self.labels_train, 'labels_train')
    _is_iterator(self.labels_test, 'labels_test')

  def __str__(self):
    """Return the shapes of variables that are not None."""
    result = ['AttackInputData(']

    if self.vocab_size is not None and self.train_size is not None:
      result.append(
          'logits_train with shape (%d, num_sequences, num_tokens, %d)' %
          (self.train_size, self.vocab_size))
      result.append(
          'labels_train with shape (%d, num_sequences, num_tokens, 1)' %
          self.train_size)

    if self.vocab_size is not None and self.test_size is not None:
      result.append(
          'logits_test with shape (%d, num_sequences, num_tokens, %d)' %
          (self.test_size, self.vocab_size))
      result.append(
          'labels_test with shape (%d, num_sequences, num_tokens, 1)' %
          self.test_size)

    result.append(')')
    return '\n'.join(result)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RocCurve:
|
class RocCurve:
|
||||||
"""Represents ROC curve of a membership inference classifier."""
|
"""Represents ROC curve of a membership inference classifier."""
|
||||||
|
|
|
@ -27,6 +27,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
||||||
|
@ -152,6 +153,75 @@ class AttackInputDataTest(absltest.TestCase):
|
||||||
probs_test=np.array([])).validate)
|
probs_test=np.array([])).validate)
|
||||||
|
|
||||||
|
|
||||||
|
class Seq2SeqAttackInputDataTest(absltest.TestCase):
  """Tests for Seq2SeqAttackInputData.validate."""

  def _assert_validate_raises(self, **kwargs):
    """Asserts that validate() raises ValueError for the given fields."""
    self.assertRaises(ValueError, Seq2SeqAttackInputData(**kwargs).validate)

  def test_validator(self):
    """Checks that incomplete or wrongly typed inputs are rejected."""
    valid_logits_train = iter([np.array([]), np.array([])])
    valid_logits_test = iter([np.array([]), np.array([])])
    valid_labels_train = iter([np.array([]), np.array([])])
    valid_labels_test = iter([np.array([]), np.array([])])

    # Plain lists are not iterators, so they must be rejected.
    invalid_logits_train = []
    invalid_logits_test = []
    invalid_labels_train = []
    invalid_labels_test = []

    zero_sizes = dict(vocab_size=0, train_size=0, test_size=0)

    # A single field on its own is never a valid input.
    lone_field_cases = (
        dict(logits_train=valid_logits_train),
        dict(labels_train=valid_labels_train),
        dict(logits_test=valid_logits_test),
        dict(labels_test=valid_labels_test),
        dict(vocab_size=0),
        dict(train_size=0),
        dict(test_size=0),
        dict(),
    )
    for case in lone_field_cases:
      self._assert_validate_raises(**case)

    # Tests that both logits and labels must be set.
    self._assert_validate_raises(
        logits_train=valid_logits_train,
        logits_test=valid_logits_test,
        **zero_sizes)
    self._assert_validate_raises(
        labels_train=valid_labels_train,
        labels_test=valid_labels_test,
        **zero_sizes)

    # Tests that vocab, train, test sizes must all be set.
    self._assert_validate_raises(
        logits_train=valid_logits_train,
        logits_test=valid_logits_test,
        labels_train=valid_labels_train,
        labels_test=valid_labels_test)

    # Tests that logits and labels must be iterators.
    self._assert_validate_raises(
        logits_train=invalid_logits_train,
        logits_test=invalid_logits_test,
        labels_train=invalid_labels_train,
        labels_test=invalid_labels_test,
        **zero_sizes)
|
||||||
|
|
||||||
|
|
||||||
class RocCurveTest(absltest.TestCase):
|
class RocCurveTest(absltest.TestCase):
|
||||||
|
|
||||||
def test_auc_random_classifier(self):
|
def test_auc_random_classifier(self):
|
||||||
|
@ -275,7 +345,6 @@ class AttackResultsCollectionTest(absltest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
class AttackResultsTest(absltest.TestCase):
|
class AttackResultsTest(absltest.TestCase):
|
||||||
|
|
||||||
perfect_classifier_result: SingleAttackResult
|
perfect_classifier_result: SingleAttackResult
|
||||||
random_classifier_result: SingleAttackResult
|
random_classifier_result: SingleAttackResult
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
||||||
PrivacyReportMetadata
|
PrivacyReportMetadata
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||||
|
@ -170,6 +171,54 @@ def run_attacks(attack_input: AttackInputData,
|
||||||
privacy_report_metadata=privacy_report_metadata)
|
privacy_report_metadata=privacy_report_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
                       unused_report_metadata: PrivacyReportMetadata = None,
                       balance_attacker_training: bool = True) -> AttackResults:
  """Runs a membership inference attack on a seq2seq model.

  Args:
    attack_input: input data for running an attack; validated before use.
    unused_report_metadata: the metadata of the model under attack. It is not
      read by this function.
    balance_attacker_training: Whether the training and test sets for the
      membership inference attacker should have a balanced (roughly equal)
      number of samples from the training and test sets used to develop the
      model under attack.

  Returns:
    AttackResults holding a single logistic-regression attack result.
  """
  attack_input.validate()

  # Each record is summarised by one number (its average rank), so logistic
  # regression is the only attacker supported here.
  lr_attacker = models.LogisticRegressionAttacker()

  attacker_data = models.create_seq2seq_attacker_data(
      attack_input, balance=balance_attacker_training)

  lr_attacker.train_model(attacker_data.features_train,
                          attacker_data.is_training_labels_train)

  # Score the held-out (permuted) attacker test examples.
  test_predictions = lr_attacker.predict(attacker_data.features_test)

  # Build the ROC curve from the membership labels and the predictions.
  fpr, tpr, thresholds = metrics.roc_curve(
      attacker_data.is_training_labels_test, test_predictions)

  single_result = SingleAttackResult(
      slice_spec=SingleSliceSpec(),
      attack_type=AttackType.LOGISTIC_REGRESSION,
      roc_curve=RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds))

  return AttackResults(single_attack_results=[single_result])
|
||||||
|
|
||||||
|
|
||||||
def _compute_missing_privacy_report_metadata(
|
def _compute_missing_privacy_report_metadata(
|
||||||
metadata: PrivacyReportMetadata,
|
metadata: PrivacyReportMetadata,
|
||||||
attack_input: AttackInputData) -> PrivacyReportMetadata:
|
attack_input: AttackInputData) -> PrivacyReportMetadata:
|
||||||
|
|
|
@ -19,6 +19,7 @@ import numpy as np
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||||
|
@ -34,6 +35,68 @@ def get_test_input(n_train, n_test):
|
||||||
labels_test=np.array([i % 5 for i in range(n_test)]))
|
labels_test=np.array([i % 5 for i in range(n_test)]))
|
||||||
|
|
||||||
|
|
||||||
|
def get_seq2seq_test_input(n_train,
                           n_test,
                           max_seq_in_batch,
                           max_tokens_in_sequence,
                           vocab_size,
                           seed=None):
  """Returns randomly generated example inputs for seq2seq attacks.

  Args:
    n_train: Number of training batches to generate.
    n_test: Number of test batches to generate.
    max_seq_in_batch: Upper bound on the number of sequences per batch.
    max_tokens_in_sequence: Upper bound on the number of tokens per sequence.
    vocab_size: Size of the vocabulary.
    seed: Optional seed for NumPy's global random state.

  Returns:
    A Seq2SeqAttackInputData whose logits and labels are iterators over the
    generated batches.
  """
  if seed is not None:
    np.random.seed(seed=seed)

  def _draw_batches(num_batches):
    """Generates `num_batches` (logits, labels) batches as two lists."""
    all_logits, all_labels = [], []
    for _ in range(num_batches):
      # choice() draws from [0, max_seq_in_batch), hence the +1.
      num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1
      batch_logits, batch_labels = _get_batch_logits_and_labels(
          num_sequences, max_tokens_in_sequence, vocab_size)
      all_logits.append(batch_logits)
      all_labels.append(batch_labels)
    return all_logits, all_labels

  # Training batches are drawn before test batches, so results stay
  # reproducible for a given seed.
  logits_train, labels_train = _draw_batches(n_train)
  logits_test, labels_test = _draw_batches(n_test)

  return Seq2SeqAttackInputData(
      logits_train=iter(logits_train),
      logits_test=iter(logits_test),
      labels_train=iter(labels_train),
      labels_test=iter(labels_test),
      vocab_size=vocab_size,
      train_size=n_train,
      test_size=n_test)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
                                 vocab_size):
  """Generates random logits and labels for one batch of sequences.

  Args:
    num_sequences: Number of sequences in the batch.
    max_tokens_in_sequence: Upper bound on the number of tokens per sequence.
    vocab_size: Size of the vocabulary.

  Returns:
    A (logits, labels) tuple of object-dtype arrays with one entry per
    sequence.
  """
  # choice() draws from [0, max_tokens_in_sequence), hence the +1.
  token_counts = np.random.choice(max_tokens_in_sequence, num_sequences) + 1
  per_sequence = [
      _get_sequence_logits_and_labels(count, vocab_size)
      for count in token_counts
  ]
  batch_logits = [logits for logits, _ in per_sequence]
  batch_labels = [labels for _, labels in per_sequence]
  return (np.array(batch_logits, dtype=object),
          np.array(batch_labels, dtype=object))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_sequence_logits_and_labels(num_tokens, vocab_size):
|
||||||
|
sequence_logits = []
|
||||||
|
for _ in range(num_tokens):
|
||||||
|
token_logits = np.random.random(vocab_size)
|
||||||
|
token_logits /= token_logits.sum()
|
||||||
|
sequence_logits.append(token_logits)
|
||||||
|
sequence_labels = np.random.choice(vocab_size, num_tokens)
|
||||||
|
return np.array(
|
||||||
|
sequence_logits, dtype=np.float32), np.array(
|
||||||
|
sequence_labels, dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
class RunAttacksTest(absltest.TestCase):
|
class RunAttacksTest(absltest.TestCase):
|
||||||
|
|
||||||
def test_run_attacks_size(self):
|
def test_run_attacks_size(self):
|
||||||
|
@ -97,6 +160,42 @@ class RunAttacksTest(absltest.TestCase):
|
||||||
# If accuracy is already present, simply return it.
|
# If accuracy is already present, simply return it.
|
||||||
self.assertIsNone(mia._get_accuracy(None, labels))
|
self.assertIsNone(mia._get_accuracy(None, labels))
|
||||||
|
|
||||||
|
def test_run_seq2seq_attack_size(self):
  """The seq2seq attack yields exactly one single-attack result."""
  attack_input = get_seq2seq_test_input(
      n_train=10,
      n_test=5,
      max_seq_in_batch=3,
      max_tokens_in_sequence=5,
      vocab_size=2)
  result = mia.run_seq2seq_attack(attack_input)

  self.assertLen(result.single_attack_results, 1)
|
||||||
|
|
||||||
|
def test_run_seq2seq_attack_trained_sets_attack_type(self):
  """The single result is tagged as a logistic regression attack."""
  attack_input = get_seq2seq_test_input(
      n_train=10,
      n_test=5,
      max_seq_in_batch=3,
      max_tokens_in_sequence=5,
      vocab_size=2)
  result = mia.run_seq2seq_attack(attack_input)

  all_results = list(result.single_attack_results)
  seq2seq_result = all_results[0]
  self.assertEqual(seq2seq_result.attack_type, AttackType.LOGISTIC_REGRESSION)
|
||||||
|
|
||||||
|
def test_run_seq2seq_attack_calculates_correct_auc(self):
  """With a fixed seed and no balancing, the AUC is 0.63 to two decimals."""
  attack_input = get_seq2seq_test_input(
      n_train=20,
      n_test=10,
      max_seq_in_batch=3,
      max_tokens_in_sequence=5,
      vocab_size=3,
      seed=12345)
  result = mia.run_seq2seq_attack(
      attack_input, balance_attacker_training=False)

  all_results = list(result.single_attack_results)
  auc = all_results[0].roc_curve.get_auc()
  np.testing.assert_almost_equal(auc, 0.63, decimal=2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
|
@ -15,8 +15,11 @@
|
||||||
# Lint as: python3
|
# Lint as: python3
|
||||||
"""Trained models for membership inference attacks."""
|
"""Trained models for membership inference attacks."""
|
||||||
|
|
||||||
|
from typing import Iterator, List
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from scipy.stats import rankdata
|
||||||
from sklearn import ensemble
|
from sklearn import ensemble
|
||||||
from sklearn import linear_model
|
from sklearn import linear_model
|
||||||
from sklearn import model_selection
|
from sklearn import model_selection
|
||||||
|
@ -24,6 +27,7 @@ from sklearn import neighbors
|
||||||
from sklearn import neural_network
|
from sklearn import neural_network
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -110,6 +114,98 @@ def _column_stack(logits, loss):
|
||||||
return np.column_stack((logits, loss))
|
return np.column_stack((logits, loss))
|
||||||
|
|
||||||
|
|
||||||
|
def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
                                 test_fraction: float = 0.25,
                                 balance: bool = True) -> AttackerData:
  """Prepare Seq2SeqAttackInputData to train ML attackers.

  Computes one average rank per batch from the logits and labels, then
  performs a random, stratified train-test split.

  Args:
    attack_input_data: Original Seq2SeqAttackInputData
    test_fraction: Fraction of the dataset to include in the test split.
    balance: Whether the training and test sets for the membership inference
      attacker should have a balanced (roughly equal) number of samples from the
      training and test sets used to develop the model under attack.

  Returns:
    AttackerData.
  """
  ranks_train = _get_average_ranks(attack_input_data.logits_train,
                                   attack_input_data.labels_train)
  ranks_test = _get_average_ranks(attack_input_data.logits_test,
                                  attack_input_data.labels_test)

  if balance:
    # Sample both sides down to the size of the smaller one.
    common_size = min(len(ranks_train), len(ranks_test))
    ranks_train = _sample_multidimensional_array(ranks_train, common_size)
    ranks_test = _sample_multidimensional_array(ranks_test, common_size)

  # One column per example: the classifier sees one-dimensional features.
  features_all = np.concatenate((ranks_train, ranks_test)).reshape(-1, 1)

  # Rows from the training set are labelled 0, rows from the test set 1.
  labels_all = np.concatenate(
      (np.zeros(len(ranks_train)), np.ones(len(ranks_test))))

  # Perform a train-test split
  split = model_selection.train_test_split(
      features_all, labels_all, test_size=test_fraction, stratify=labels_all)
  (features_train, features_test, is_training_labels_train,
   is_training_labels_test) = split

  return AttackerData(features_train, is_training_labels_train, features_test,
                      is_training_labels_test)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_average_ranks(logits: Iterator[np.ndarray],
                       labels: Iterator[np.ndarray]) -> np.ndarray:
  """Returns the average rank of tokens in a batch of sequences.

  Args:
    logits: Logits returned by a seq2seq model, dim = (num_batches,
      num_sequences, num_tokens, vocab_size).
    labels: Target labels for the seq2seq model, dim = (num_batches,
      num_sequences, num_tokens, 1).

  Returns:
    An array with one average rank per batch. Each average is taken over the
    ranks of all tokens in all sequences of that batch.
  """
  average_ranks = []
  for batch_logits, batch_labels in zip(logits, labels):
    # Flatten the per-sequence rank lists into one list for the whole batch.
    token_ranks = [
        rank
        for sequence_logits, sequence_labels in zip(batch_logits, batch_labels)
        for rank in _get_ranks_for_sequence(sequence_logits, sequence_labels)
    ]
    average_ranks.append(np.mean(token_ranks))

  return np.array(average_ranks)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_ranks_for_sequence(logits: np.ndarray,
|
||||||
|
labels: np.ndarray) -> List[float]:
|
||||||
|
"""Returns ranks for a sequence.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: Logits of a single sequence, dim = (num_tokens, vocab_size).
|
||||||
|
labels: Target labels of a single sequence, dim = (num_tokens, 1).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An array of ranks for tokens in the sequence, dim = (num_tokens, 1).
|
||||||
|
"""
|
||||||
|
sequence_ranks = []
|
||||||
|
for logit, label in zip(logits, labels.astype(int)):
|
||||||
|
rank = rankdata(-logit, method='min')[label] - 1.0
|
||||||
|
sequence_ranks.append(rank)
|
||||||
|
|
||||||
|
return sequence_ranks
|
||||||
|
|
||||||
|
|
||||||
class TrainedAttacker:
|
class TrainedAttacker:
|
||||||
"""Base class for training attack models."""
|
"""Base class for training attack models."""
|
||||||
model = None
|
model = None
|
||||||
|
|
|
@ -19,6 +19,7 @@ import numpy as np
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack import models
|
from tensorflow_privacy.privacy.membership_inference_attack import models
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
|
||||||
|
|
||||||
|
|
||||||
class TrainedAttackerTest(absltest.TestCase):
|
class TrainedAttackerTest(absltest.TestCase):
|
||||||
|
@ -55,6 +56,66 @@ class TrainedAttackerTest(absltest.TestCase):
|
||||||
expected = feature[:2] not in attack_input.logits_train
|
expected = feature[:2] not in attack_input.logits_train
|
||||||
self.assertEqual(attacker_data.is_training_labels_train[i], expected)
|
self.assertEqual(attacker_data.is_training_labels_train[i], expected)
|
||||||
|
|
||||||
|
def test_create_seq2seq_attacker_data_logits_and_labels(self):
  """Unbalanced split of 3 train + 2 test batches gives 3/2 features."""

  def f32(values):
    return np.array(values, dtype=np.float32)

  def batch(sequences):
    return np.array(sequences, dtype=object)

  logits_train = [
      batch([
          f32([[0.1, 0.1, 0.8], [0.7, 0.3, 0]]),
          f32([[0.4, 0.5, 0.1]]),
      ]),
      batch([f32([[0.25, 0.6, 0.15], [1, 0, 0]])]),
      batch([
          f32([[0.9, 0, 0.1], [0.25, 0.5, 0.25]]),
          f32([[0, 1, 0], [0.2, 0.1, 0.7]]),
      ]),
  ]
  logits_test = [
      batch([f32([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]])]),
      batch([
          f32([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]]),
          f32([[0.3, 0.35, 0.35]]),
      ]),
  ]
  labels_train = [
      batch([f32([2, 0]), f32([1])]),
      batch([f32([1, 0])]),
      batch([f32([0, 1]), f32([1, 2])]),
  ]
  labels_test = [
      # This batch is built without dtype=object, as in the original fixture.
      np.array([f32([2, 1])]),
      batch([f32([2, 0]), f32([1])]),
  ]

  attack_input = Seq2SeqAttackInputData(
      logits_train=iter(logits_train),
      logits_test=iter(logits_test),
      labels_train=iter(labels_train),
      labels_test=iter(labels_test),
      vocab_size=3,
      train_size=3,
      test_size=2)
  attacker_data = models.create_seq2seq_attacker_data(
      attack_input, 0.25, balance=False)
  self.assertLen(attacker_data.features_train, 3)
  self.assertLen(attacker_data.features_test, 2)

  for feature in attacker_data.features_train:
    self.assertLen(feature, 1)  # each feature has one average rank
|
||||||
|
|
||||||
def test_balanced_create_attacker_data_loss_and_logits(self):
|
def test_balanced_create_attacker_data_loss_and_logits(self):
|
||||||
attack_input = AttackInputData(
|
attack_input = AttackInputData(
|
||||||
logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
|
logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
|
||||||
|
@ -70,6 +131,71 @@ class TrainedAttackerTest(absltest.TestCase):
|
||||||
expected = feature[:2] not in attack_input.logits_train
|
expected = feature[:2] not in attack_input.logits_train
|
||||||
self.assertEqual(attacker_data.is_training_labels_train[i], expected)
|
self.assertEqual(attacker_data.is_training_labels_train[i], expected)
|
||||||
|
|
||||||
|
def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
  """Balanced split of 3 train + 3 test batches gives 4/2 features."""

  def f32(values):
    return np.array(values, dtype=np.float32)

  def batch(sequences):
    return np.array(sequences, dtype=object)

  logits_train = [
      batch([
          f32([[0.1, 0.1, 0.8], [0.7, 0.3, 0]]),
          f32([[0.4, 0.5, 0.1]]),
      ]),
      batch([f32([[0.25, 0.6, 0.15], [1, 0, 0]])]),
      batch([
          f32([[0.9, 0, 0.1], [0.25, 0.5, 0.25]]),
          f32([[0, 1, 0], [0.2, 0.1, 0.7]]),
      ]),
  ]
  logits_test = [
      batch([f32([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]])]),
      batch([
          f32([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]]),
          f32([[0.3, 0.35, 0.35]]),
      ]),
      batch([f32([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]])]),
  ]
  labels_train = [
      batch([f32([2, 0]), f32([1])]),
      batch([f32([1, 0])]),
      batch([f32([0, 1]), f32([1, 2])]),
  ]
  labels_test = [
      # The first and last batches are built without dtype=object, as in the
      # original fixture.
      np.array([f32([2, 1])]),
      batch([f32([2, 0]), f32([1])]),
      np.array([f32([2, 1])]),
  ]

  attack_input = Seq2SeqAttackInputData(
      logits_train=iter(logits_train),
      logits_test=iter(logits_test),
      labels_train=iter(labels_train),
      labels_test=iter(labels_test),
      vocab_size=3,
      train_size=3,
      test_size=3)
  attacker_data = models.create_seq2seq_attacker_data(
      attack_input, 0.33, balance=True)
  self.assertLen(attacker_data.features_train, 4)
  self.assertLen(attacker_data.features_test, 2)

  for feature in attacker_data.features_train:
    self.assertLen(feature, 1)  # each feature has one average rank
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
Loading…
Reference in a new issue