Merge pull request #137 from amad-person:add_seq2seq_mia_attacks
PiperOrigin-RevId: 343047622
This commit is contained in:
commit
35a8096173
8 changed files with 1738 additions and 2 deletions
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2019 Congzheng Song
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
File diff suppressed because one or more lines are too long
|
@ -18,7 +18,7 @@ import enum
|
|||
import glob
|
||||
import os
|
||||
import pickle
|
||||
from typing import Any, Iterable, Union
|
||||
from typing import Any, Iterable, Union, Iterator
|
||||
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
|
@ -378,6 +378,91 @@ def _append_array_shape(arr: np.array, arr_name: str, result):
|
|||
result.append(' %s with shape: %s,' % (arr_name, arr.shape))
|
||||
|
||||
|
||||
def _is_iterator(obj, obj_name):
|
||||
"""Checks whether obj is a generator."""
|
||||
if obj is not None and not isinstance(obj, Iterator):
|
||||
raise ValueError('%s should be a generator.' % obj_name)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Seq2SeqAttackInputData:
|
||||
"""Input data for running an attack on seq2seq models.
|
||||
|
||||
This includes only the data, and not configuration.
|
||||
"""
|
||||
logits_train: Iterator[np.ndarray] = None
|
||||
logits_test: Iterator[np.ndarray] = None
|
||||
|
||||
# Contains ground-truth token indices for the target sequences.
|
||||
labels_train: Iterator[np.ndarray] = None
|
||||
labels_test: Iterator[np.ndarray] = None
|
||||
|
||||
# Size of the target sequence vocabulary.
|
||||
vocab_size: int = None
|
||||
|
||||
# Train, test size = number of batches in training, test set.
|
||||
# These values need to be supplied by the user as logits, labels
|
||||
# are lazy loaded for seq2seq models.
|
||||
train_size: int = 0
|
||||
test_size: int = 0
|
||||
|
||||
def validate(self):
|
||||
"""Validates the inputs."""
|
||||
|
||||
if (self.logits_train is None) != (self.logits_test is None):
|
||||
raise ValueError(
|
||||
'logits_train and logits_test should both be either set or unset')
|
||||
|
||||
if (self.labels_train is None) != (self.labels_test is None):
|
||||
raise ValueError(
|
||||
'labels_train and labels_test should both be either set or unset')
|
||||
|
||||
if self.logits_train is None or self.labels_train is None:
|
||||
raise ValueError(
|
||||
'Labels, logits of training, test sets should all be set')
|
||||
|
||||
if (self.vocab_size is None or self.train_size is None or
|
||||
self.test_size is None):
|
||||
raise ValueError('vocab_size, train_size, test_size should all be set')
|
||||
|
||||
if self.vocab_size is not None and not int:
|
||||
raise ValueError('vocab_size should be of integer type')
|
||||
|
||||
if self.train_size is not None and not int:
|
||||
raise ValueError('train_size should be of integer type')
|
||||
|
||||
if self.test_size is not None and not int:
|
||||
raise ValueError('test_size should be of integer type')
|
||||
|
||||
_is_iterator(self.logits_train, 'logits_train')
|
||||
_is_iterator(self.logits_test, 'logits_test')
|
||||
_is_iterator(self.labels_train, 'labels_train')
|
||||
_is_iterator(self.labels_test, 'labels_test')
|
||||
|
||||
def __str__(self):
|
||||
"""Return the shapes of variables that are not None."""
|
||||
result = ['AttackInputData(']
|
||||
|
||||
if self.vocab_size is not None and self.train_size is not None:
|
||||
result.append(
|
||||
'logits_train with shape (%d, num_sequences, num_tokens, %d)' %
|
||||
(self.train_size, self.vocab_size))
|
||||
result.append(
|
||||
'labels_train with shape (%d, num_sequences, num_tokens, 1)' %
|
||||
self.train_size)
|
||||
|
||||
if self.vocab_size is not None and self.test_size is not None:
|
||||
result.append(
|
||||
'logits_test with shape (%d, num_sequences, num_tokens, %d)' %
|
||||
(self.test_size, self.vocab_size))
|
||||
result.append(
|
||||
'labels_test with shape (%d, num_sequences, num_tokens, 1)' %
|
||||
self.test_size)
|
||||
|
||||
result.append(')')
|
||||
return '\n'.join(result)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RocCurve:
|
||||
"""Represents ROC curve of a membership inference classifier."""
|
||||
|
|
|
@ -27,6 +27,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
|
|||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
||||
|
@ -152,6 +153,75 @@ class AttackInputDataTest(absltest.TestCase):
|
|||
probs_test=np.array([])).validate)
|
||||
|
||||
|
||||
class Seq2SeqAttackInputDataTest(absltest.TestCase):
|
||||
|
||||
def test_validator(self):
|
||||
valid_logits_train = iter([np.array([]), np.array([])])
|
||||
valid_logits_test = iter([np.array([]), np.array([])])
|
||||
valid_labels_train = iter([np.array([]), np.array([])])
|
||||
valid_labels_test = iter([np.array([]), np.array([])])
|
||||
|
||||
invalid_logits_train = []
|
||||
invalid_logits_test = []
|
||||
invalid_labels_train = []
|
||||
invalid_labels_test = []
|
||||
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
Seq2SeqAttackInputData(logits_train=valid_logits_train).validate)
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
Seq2SeqAttackInputData(labels_train=valid_labels_train).validate)
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
Seq2SeqAttackInputData(logits_test=valid_logits_test).validate)
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
Seq2SeqAttackInputData(labels_test=valid_labels_test).validate)
|
||||
self.assertRaises(ValueError, Seq2SeqAttackInputData(vocab_size=0).validate)
|
||||
self.assertRaises(ValueError, Seq2SeqAttackInputData(train_size=0).validate)
|
||||
self.assertRaises(ValueError, Seq2SeqAttackInputData(test_size=0).validate)
|
||||
self.assertRaises(ValueError, Seq2SeqAttackInputData().validate)
|
||||
|
||||
# Tests that both logits and labels must be set.
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
Seq2SeqAttackInputData(
|
||||
logits_train=valid_logits_train,
|
||||
logits_test=valid_logits_test,
|
||||
vocab_size=0,
|
||||
train_size=0,
|
||||
test_size=0).validate)
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
Seq2SeqAttackInputData(
|
||||
labels_train=valid_labels_train,
|
||||
labels_test=valid_labels_test,
|
||||
vocab_size=0,
|
||||
train_size=0,
|
||||
test_size=0).validate)
|
||||
|
||||
# Tests that vocab, train, test sizes must all be set.
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
Seq2SeqAttackInputData(
|
||||
logits_train=valid_logits_train,
|
||||
logits_test=valid_logits_test,
|
||||
labels_train=valid_labels_train,
|
||||
labels_test=valid_labels_test).validate)
|
||||
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
Seq2SeqAttackInputData(
|
||||
logits_train=invalid_logits_train,
|
||||
logits_test=invalid_logits_test,
|
||||
labels_train=invalid_labels_train,
|
||||
labels_test=invalid_labels_test,
|
||||
vocab_size=0,
|
||||
train_size=0,
|
||||
test_size=0).validate)
|
||||
|
||||
|
||||
class RocCurveTest(absltest.TestCase):
|
||||
|
||||
def test_auc_random_classifier(self):
|
||||
|
@ -275,7 +345,6 @@ class AttackResultsCollectionTest(absltest.TestCase):
|
|||
|
||||
|
||||
class AttackResultsTest(absltest.TestCase):
|
||||
|
||||
perfect_classifier_result: SingleAttackResult
|
||||
random_classifier_result: SingleAttackResult
|
||||
|
||||
|
|
|
@ -30,6 +30,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
|
|||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
|
||||
PrivacyReportMetadata
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
|
@ -170,6 +171,54 @@ def run_attacks(attack_input: AttackInputData,
|
|||
privacy_report_metadata=privacy_report_metadata)
|
||||
|
||||
|
||||
def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
||||
unused_report_metadata: PrivacyReportMetadata = None,
|
||||
balance_attacker_training: bool = True) -> AttackResults:
|
||||
"""Runs membership inference attacks on a seq2seq model.
|
||||
|
||||
Args:
|
||||
attack_input: input data for running an attack
|
||||
unused_report_metadata: the metadata of the model under attack.
|
||||
balance_attacker_training: Whether the training and test sets for the
|
||||
membership inference attacker should have a balanced (roughly equal)
|
||||
number of samples from the training and test sets used to develop the
|
||||
model under attack.
|
||||
|
||||
Returns:
|
||||
the attack result.
|
||||
"""
|
||||
attack_input.validate()
|
||||
|
||||
# The attacker uses the average rank (a single number) of a seq2seq dataset
|
||||
# record to determine membership. So only Logistic Regression is supported,
|
||||
# as it makes the most sense for single-number features.
|
||||
attacker = models.LogisticRegressionAttacker()
|
||||
|
||||
prepared_attacker_data = models.create_seq2seq_attacker_data(
|
||||
attack_input, balance=balance_attacker_training)
|
||||
|
||||
attacker.train_model(prepared_attacker_data.features_train,
|
||||
prepared_attacker_data.is_training_labels_train)
|
||||
|
||||
# Run the attacker on (permuted) test examples.
|
||||
predictions_test = attacker.predict(prepared_attacker_data.features_test)
|
||||
|
||||
# Generate ROC curves with predictions.
|
||||
fpr, tpr, thresholds = metrics.roc_curve(
|
||||
prepared_attacker_data.is_training_labels_test, predictions_test)
|
||||
|
||||
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
|
||||
|
||||
attack_results = [
|
||||
SingleAttackResult(
|
||||
slice_spec=SingleSliceSpec(),
|
||||
attack_type=AttackType.LOGISTIC_REGRESSION,
|
||||
roc_curve=roc_curve)
|
||||
]
|
||||
|
||||
return AttackResults(single_attack_results=attack_results)
|
||||
|
||||
|
||||
def _compute_missing_privacy_report_metadata(
|
||||
metadata: PrivacyReportMetadata,
|
||||
attack_input: AttackInputData) -> PrivacyReportMetadata:
|
||||
|
|
|
@ -19,6 +19,7 @@ import numpy as np
|
|||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
|
||||
|
@ -34,6 +35,68 @@ def get_test_input(n_train, n_test):
|
|||
labels_test=np.array([i % 5 for i in range(n_test)]))
|
||||
|
||||
|
||||
def get_seq2seq_test_input(n_train,
|
||||
n_test,
|
||||
max_seq_in_batch,
|
||||
max_tokens_in_sequence,
|
||||
vocab_size,
|
||||
seed=None):
|
||||
"""Returns example inputs for attacks on seq2seq models."""
|
||||
if seed is not None:
|
||||
np.random.seed(seed=seed)
|
||||
|
||||
logits_train, labels_train = [], []
|
||||
for _ in range(n_train):
|
||||
num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1
|
||||
batch_logits, batch_labels = _get_batch_logits_and_labels(
|
||||
num_sequences, max_tokens_in_sequence, vocab_size)
|
||||
logits_train.append(batch_logits)
|
||||
labels_train.append(batch_labels)
|
||||
|
||||
logits_test, labels_test = [], []
|
||||
for _ in range(n_test):
|
||||
num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1
|
||||
batch_logits, batch_labels = _get_batch_logits_and_labels(
|
||||
num_sequences, max_tokens_in_sequence, vocab_size)
|
||||
logits_test.append(batch_logits)
|
||||
labels_test.append(batch_labels)
|
||||
|
||||
return Seq2SeqAttackInputData(
|
||||
logits_train=iter(logits_train),
|
||||
logits_test=iter(logits_test),
|
||||
labels_train=iter(labels_train),
|
||||
labels_test=iter(labels_test),
|
||||
vocab_size=vocab_size,
|
||||
train_size=n_train,
|
||||
test_size=n_test)
|
||||
|
||||
|
||||
def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
|
||||
vocab_size):
|
||||
num_tokens_in_sequence = np.random.choice(max_tokens_in_sequence,
|
||||
num_sequences) + 1
|
||||
batch_logits, batch_labels = [], []
|
||||
for num_tokens in num_tokens_in_sequence:
|
||||
logits, labels = _get_sequence_logits_and_labels(num_tokens, vocab_size)
|
||||
batch_logits.append(logits)
|
||||
batch_labels.append(labels)
|
||||
return np.array(
|
||||
batch_logits, dtype=object), np.array(
|
||||
batch_labels, dtype=object)
|
||||
|
||||
|
||||
def _get_sequence_logits_and_labels(num_tokens, vocab_size):
|
||||
sequence_logits = []
|
||||
for _ in range(num_tokens):
|
||||
token_logits = np.random.random(vocab_size)
|
||||
token_logits /= token_logits.sum()
|
||||
sequence_logits.append(token_logits)
|
||||
sequence_labels = np.random.choice(vocab_size, num_tokens)
|
||||
return np.array(
|
||||
sequence_logits, dtype=np.float32), np.array(
|
||||
sequence_labels, dtype=np.float32)
|
||||
|
||||
|
||||
class RunAttacksTest(absltest.TestCase):
|
||||
|
||||
def test_run_attacks_size(self):
|
||||
|
@ -97,6 +160,42 @@ class RunAttacksTest(absltest.TestCase):
|
|||
# If accuracy is already present, simply return it.
|
||||
self.assertIsNone(mia._get_accuracy(None, labels))
|
||||
|
||||
def test_run_seq2seq_attack_size(self):
|
||||
result = mia.run_seq2seq_attack(
|
||||
get_seq2seq_test_input(
|
||||
n_train=10,
|
||||
n_test=5,
|
||||
max_seq_in_batch=3,
|
||||
max_tokens_in_sequence=5,
|
||||
vocab_size=2))
|
||||
|
||||
self.assertLen(result.single_attack_results, 1)
|
||||
|
||||
def test_run_seq2seq_attack_trained_sets_attack_type(self):
|
||||
result = mia.run_seq2seq_attack(
|
||||
get_seq2seq_test_input(
|
||||
n_train=10,
|
||||
n_test=5,
|
||||
max_seq_in_batch=3,
|
||||
max_tokens_in_sequence=5,
|
||||
vocab_size=2))
|
||||
seq2seq_result = list(result.single_attack_results)[0]
|
||||
self.assertEqual(seq2seq_result.attack_type, AttackType.LOGISTIC_REGRESSION)
|
||||
|
||||
def test_run_seq2seq_attack_calculates_correct_auc(self):
|
||||
result = mia.run_seq2seq_attack(
|
||||
get_seq2seq_test_input(
|
||||
n_train=20,
|
||||
n_test=10,
|
||||
max_seq_in_batch=3,
|
||||
max_tokens_in_sequence=5,
|
||||
vocab_size=3,
|
||||
seed=12345),
|
||||
balance_attacker_training=False)
|
||||
seq2seq_result = list(result.single_attack_results)[0]
|
||||
np.testing.assert_almost_equal(
|
||||
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
|
@ -15,8 +15,11 @@
|
|||
# Lint as: python3
|
||||
"""Trained models for membership inference attacks."""
|
||||
|
||||
from typing import Iterator, List
|
||||
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from scipy.stats import rankdata
|
||||
from sklearn import ensemble
|
||||
from sklearn import linear_model
|
||||
from sklearn import model_selection
|
||||
|
@ -24,6 +27,7 @@ from sklearn import neighbors
|
|||
from sklearn import neural_network
|
||||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -110,6 +114,98 @@ def _column_stack(logits, loss):
|
|||
return np.column_stack((logits, loss))
|
||||
|
||||
|
||||
def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
|
||||
test_fraction: float = 0.25,
|
||||
balance: bool = True) -> AttackerData:
|
||||
"""Prepare Seq2SeqAttackInputData to train ML attackers.
|
||||
|
||||
Uses logits and losses to generate ranks and performs a random train-test
|
||||
split.
|
||||
|
||||
Args:
|
||||
attack_input_data: Original Seq2SeqAttackInputData
|
||||
test_fraction: Fraction of the dataset to include in the test split.
|
||||
balance: Whether the training and test sets for the membership inference
|
||||
attacker should have a balanced (roughly equal) number of samples from the
|
||||
training and test sets used to develop the model under attack.
|
||||
|
||||
Returns:
|
||||
AttackerData.
|
||||
"""
|
||||
attack_input_train = _get_average_ranks(attack_input_data.logits_train,
|
||||
attack_input_data.labels_train)
|
||||
attack_input_test = _get_average_ranks(attack_input_data.logits_test,
|
||||
attack_input_data.labels_test)
|
||||
|
||||
if balance:
|
||||
min_size = min(len(attack_input_train), len(attack_input_test))
|
||||
attack_input_train = _sample_multidimensional_array(attack_input_train,
|
||||
min_size)
|
||||
attack_input_test = _sample_multidimensional_array(attack_input_test,
|
||||
min_size)
|
||||
|
||||
features_all = np.concatenate((attack_input_train, attack_input_test))
|
||||
|
||||
# Reshape for classifying one-dimensional features
|
||||
features_all = features_all.reshape(-1, 1)
|
||||
|
||||
labels_all = np.concatenate(
|
||||
((np.zeros(len(attack_input_train))), (np.ones(len(attack_input_test)))))
|
||||
|
||||
# Perform a train-test split
|
||||
features_train, features_test, \
|
||||
is_training_labels_train, is_training_labels_test = \
|
||||
model_selection.train_test_split(
|
||||
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
|
||||
|
||||
return AttackerData(features_train, is_training_labels_train, features_test,
|
||||
is_training_labels_test)
|
||||
|
||||
|
||||
def _get_average_ranks(logits: Iterator[np.ndarray],
|
||||
labels: Iterator[np.ndarray]) -> np.ndarray:
|
||||
"""Returns the average rank of tokens in a batch of sequences.
|
||||
|
||||
Args:
|
||||
logits: Logits returned by a seq2seq model, dim = (num_batches,
|
||||
num_sequences, num_tokens, vocab_size).
|
||||
labels: Target labels for the seq2seq model, dim = (num_batches,
|
||||
num_sequences, num_tokens, 1).
|
||||
|
||||
Returns:
|
||||
An array of average ranks, dim = (num_batches, 1).
|
||||
Each average rank is calculated over ranks of tokens in sequences of a
|
||||
particular batch.
|
||||
"""
|
||||
ranks = []
|
||||
for batch_logits, batch_labels in zip(logits, labels):
|
||||
batch_ranks = []
|
||||
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
|
||||
batch_ranks += _get_ranks_for_sequence(sequence_logits, sequence_labels)
|
||||
ranks.append(np.mean(batch_ranks))
|
||||
|
||||
return np.array(ranks)
|
||||
|
||||
|
||||
def _get_ranks_for_sequence(logits: np.ndarray,
|
||||
labels: np.ndarray) -> List[float]:
|
||||
"""Returns ranks for a sequence.
|
||||
|
||||
Args:
|
||||
logits: Logits of a single sequence, dim = (num_tokens, vocab_size).
|
||||
labels: Target labels of a single sequence, dim = (num_tokens, 1).
|
||||
|
||||
Returns:
|
||||
An array of ranks for tokens in the sequence, dim = (num_tokens, 1).
|
||||
"""
|
||||
sequence_ranks = []
|
||||
for logit, label in zip(logits, labels.astype(int)):
|
||||
rank = rankdata(-logit, method='min')[label] - 1.0
|
||||
sequence_ranks.append(rank)
|
||||
|
||||
return sequence_ranks
|
||||
|
||||
|
||||
class TrainedAttacker:
|
||||
"""Base class for training attack models."""
|
||||
model = None
|
||||
|
|
|
@ -19,6 +19,7 @@ import numpy as np
|
|||
|
||||
from tensorflow_privacy.privacy.membership_inference_attack import models
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
|
||||
|
||||
|
||||
class TrainedAttackerTest(absltest.TestCase):
|
||||
|
@ -55,6 +56,66 @@ class TrainedAttackerTest(absltest.TestCase):
|
|||
expected = feature[:2] not in attack_input.logits_train
|
||||
self.assertEqual(attacker_data.is_training_labels_train[i], expected)
|
||||
|
||||
def test_create_seq2seq_attacker_data_logits_and_labels(self):
|
||||
attack_input = Seq2SeqAttackInputData(
|
||||
logits_train=iter([
|
||||
np.array([
|
||||
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
|
||||
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
np.array(
|
||||
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
|
||||
dtype=object),
|
||||
np.array([
|
||||
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
|
||||
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
]),
|
||||
logits_test=iter([
|
||||
np.array([
|
||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
np.array([
|
||||
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
|
||||
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
]),
|
||||
labels_train=iter([
|
||||
np.array([
|
||||
np.array([2, 0], dtype=np.float32),
|
||||
np.array([1], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
|
||||
np.array([
|
||||
np.array([0, 1], dtype=np.float32),
|
||||
np.array([1, 2], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
]),
|
||||
labels_test=iter([
|
||||
np.array([np.array([2, 1], dtype=np.float32)]),
|
||||
np.array([
|
||||
np.array([2, 0], dtype=np.float32),
|
||||
np.array([1], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
]),
|
||||
vocab_size=3,
|
||||
train_size=3,
|
||||
test_size=2)
|
||||
attacker_data = models.create_seq2seq_attacker_data(
|
||||
attack_input, 0.25, balance=False)
|
||||
self.assertLen(attacker_data.features_train, 3)
|
||||
self.assertLen(attacker_data.features_test, 2)
|
||||
|
||||
for _, feature in enumerate(attacker_data.features_train):
|
||||
self.assertLen(feature, 1) # each feature has one average rank
|
||||
|
||||
def test_balanced_create_attacker_data_loss_and_logits(self):
|
||||
attack_input = AttackInputData(
|
||||
logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
|
||||
|
@ -70,6 +131,71 @@ class TrainedAttackerTest(absltest.TestCase):
|
|||
expected = feature[:2] not in attack_input.logits_train
|
||||
self.assertEqual(attacker_data.is_training_labels_train[i], expected)
|
||||
|
||||
def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
|
||||
attack_input = Seq2SeqAttackInputData(
|
||||
logits_train=iter([
|
||||
np.array([
|
||||
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
|
||||
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
np.array(
|
||||
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
|
||||
dtype=object),
|
||||
np.array([
|
||||
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
|
||||
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
]),
|
||||
logits_test=iter([
|
||||
np.array([
|
||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
np.array([
|
||||
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
|
||||
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
np.array([
|
||||
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
]),
|
||||
labels_train=iter([
|
||||
np.array([
|
||||
np.array([2, 0], dtype=np.float32),
|
||||
np.array([1], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
|
||||
np.array([
|
||||
np.array([0, 1], dtype=np.float32),
|
||||
np.array([1, 2], dtype=np.float32)
|
||||
],
|
||||
dtype=object)
|
||||
]),
|
||||
labels_test=iter([
|
||||
np.array([np.array([2, 1], dtype=np.float32)]),
|
||||
np.array([
|
||||
np.array([2, 0], dtype=np.float32),
|
||||
np.array([1], dtype=np.float32)
|
||||
],
|
||||
dtype=object),
|
||||
np.array([np.array([2, 1], dtype=np.float32)])
|
||||
]),
|
||||
vocab_size=3,
|
||||
train_size=3,
|
||||
test_size=3)
|
||||
attacker_data = models.create_seq2seq_attacker_data(
|
||||
attack_input, 0.33, balance=True)
|
||||
self.assertLen(attacker_data.features_train, 4)
|
||||
self.assertLen(attacker_data.features_test, 2)
|
||||
|
||||
for _, feature in enumerate(attacker_data.features_train):
|
||||
self.assertLen(feature, 1) # each feature has one average rank
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue