Merge pull request #137 from amad-person:add_seq2seq_mia_attacks

PiperOrigin-RevId: 343047622
A. Unique TensorFlower 2020-11-18 03:26:24 -08:00
commit 35a8096173
8 changed files with 1738 additions and 2 deletions


@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Congzheng Song

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -18,7 +18,7 @@ import enum
import glob
import os
import pickle
from typing import Any, Iterable, Union, Iterator
from dataclasses import dataclass
import numpy as np
@@ -378,6 +378,91 @@ def _append_array_shape(arr: np.array, arr_name: str, result):
  result.append(' %s with shape: %s,' % (arr_name, arr.shape))
def _is_iterator(obj, obj_name):
  """Checks whether obj is an iterator."""
  if obj is not None and not isinstance(obj, Iterator):
    raise ValueError('%s should be an iterator.' % obj_name)
@dataclass
class Seq2SeqAttackInputData:
"""Input data for running an attack on seq2seq models.
This includes only the data, and not configuration.
"""
logits_train: Iterator[np.ndarray] = None
logits_test: Iterator[np.ndarray] = None
# Contains ground-truth token indices for the target sequences.
labels_train: Iterator[np.ndarray] = None
labels_test: Iterator[np.ndarray] = None
# Size of the target sequence vocabulary.
vocab_size: int = None
  # train_size and test_size are the number of batches in the training and
  # test sets, respectively. The user must supply these values because logits
  # and labels are lazily loaded for seq2seq models.
train_size: int = 0
test_size: int = 0
def validate(self):
"""Validates the inputs."""
if (self.logits_train is None) != (self.logits_test is None):
raise ValueError(
'logits_train and logits_test should both be either set or unset')
if (self.labels_train is None) != (self.labels_test is None):
raise ValueError(
'labels_train and labels_test should both be either set or unset')
    if self.logits_train is None or self.labels_train is None:
      raise ValueError(
          'Logits and labels of training and test sets should all be set')
    if (self.vocab_size is None or self.train_size is None or
        self.test_size is None):
      raise ValueError('vocab_size, train_size, test_size should all be set')
    if self.vocab_size is not None and not isinstance(self.vocab_size, int):
      raise ValueError('vocab_size should be of integer type')
    if self.train_size is not None and not isinstance(self.train_size, int):
      raise ValueError('train_size should be of integer type')
    if self.test_size is not None and not isinstance(self.test_size, int):
      raise ValueError('test_size should be of integer type')
_is_iterator(self.logits_train, 'logits_train')
_is_iterator(self.logits_test, 'logits_test')
_is_iterator(self.labels_train, 'labels_train')
_is_iterator(self.labels_test, 'labels_test')
  def __str__(self):
    """Returns the shapes of the variables that are set."""
    result = ['Seq2SeqAttackInputData(']
if self.vocab_size is not None and self.train_size is not None:
result.append(
'logits_train with shape (%d, num_sequences, num_tokens, %d)' %
(self.train_size, self.vocab_size))
result.append(
'labels_train with shape (%d, num_sequences, num_tokens, 1)' %
self.train_size)
if self.vocab_size is not None and self.test_size is not None:
result.append(
'logits_test with shape (%d, num_sequences, num_tokens, %d)' %
(self.test_size, self.vocab_size))
result.append(
'labels_test with shape (%d, num_sequences, num_tokens, 1)' %
self.test_size)
result.append(')')
return '\n'.join(result)
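Editor's example (not part of the diff): a minimal sketch of how Seq2SeqAttackInputData is populated and validated. The toy batches below are hypothetical; real callers would stream logits and labels from a seq2seq model one batch at a time.

import numpy as np

from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData

# One toy batch per set: each batch is an object array of sequences, and each
# sequence holds one probability row of length vocab_size per token.
toy_input = Seq2SeqAttackInputData(
    logits_train=iter([
        np.array([np.array([[0.2, 0.8], [0.6, 0.4]], dtype=np.float32)],
                 dtype=object)
    ]),
    logits_test=iter([
        np.array([np.array([[0.5, 0.5]], dtype=np.float32)], dtype=object)
    ]),
    labels_train=iter([
        np.array([np.array([1, 0], dtype=np.float32)], dtype=object)
    ]),
    labels_test=iter([
        np.array([np.array([0], dtype=np.float32)], dtype=object)
    ]),
    vocab_size=2,
    train_size=1,
    test_size=1)

toy_input.validate()  # Passes: all fields are set and the data fields are iterators.
print(toy_input)      # Prints the expected shapes, based on train_size/test_size/vocab_size.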
@dataclass
class RocCurve:
  """Represents ROC curve of a membership inference classifier."""


@@ -27,6 +27,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
@@ -152,6 +153,75 @@ class AttackInputDataTest(absltest.TestCase):
        probs_test=np.array([])).validate)
class Seq2SeqAttackInputDataTest(absltest.TestCase):
def test_validator(self):
valid_logits_train = iter([np.array([]), np.array([])])
valid_logits_test = iter([np.array([]), np.array([])])
valid_labels_train = iter([np.array([]), np.array([])])
valid_labels_test = iter([np.array([]), np.array([])])
invalid_logits_train = []
invalid_logits_test = []
invalid_labels_train = []
invalid_labels_test = []
self.assertRaises(
ValueError,
Seq2SeqAttackInputData(logits_train=valid_logits_train).validate)
self.assertRaises(
ValueError,
Seq2SeqAttackInputData(labels_train=valid_labels_train).validate)
self.assertRaises(
ValueError,
Seq2SeqAttackInputData(logits_test=valid_logits_test).validate)
self.assertRaises(
ValueError,
Seq2SeqAttackInputData(labels_test=valid_labels_test).validate)
self.assertRaises(ValueError, Seq2SeqAttackInputData(vocab_size=0).validate)
self.assertRaises(ValueError, Seq2SeqAttackInputData(train_size=0).validate)
self.assertRaises(ValueError, Seq2SeqAttackInputData(test_size=0).validate)
self.assertRaises(ValueError, Seq2SeqAttackInputData().validate)
# Tests that both logits and labels must be set.
self.assertRaises(
ValueError,
Seq2SeqAttackInputData(
logits_train=valid_logits_train,
logits_test=valid_logits_test,
vocab_size=0,
train_size=0,
test_size=0).validate)
self.assertRaises(
ValueError,
Seq2SeqAttackInputData(
labels_train=valid_labels_train,
labels_test=valid_labels_test,
vocab_size=0,
train_size=0,
test_size=0).validate)
# Tests that vocab, train, test sizes must all be set.
self.assertRaises(
ValueError,
Seq2SeqAttackInputData(
logits_train=valid_logits_train,
logits_test=valid_logits_test,
labels_train=valid_labels_train,
labels_test=valid_labels_test).validate)
self.assertRaises(
ValueError,
Seq2SeqAttackInputData(
logits_train=invalid_logits_train,
logits_test=invalid_logits_test,
labels_train=invalid_labels_train,
labels_test=invalid_labels_test,
vocab_size=0,
train_size=0,
test_size=0).validate)
class RocCurveTest(absltest.TestCase):
  def test_auc_random_classifier(self):
@@ -275,7 +345,6 @@ class AttackResultsCollectionTest(absltest.TestCase):
class AttackResultsTest(absltest.TestCase):
  perfect_classifier_result: SingleAttackResult
  random_classifier_result: SingleAttackResult


@@ -30,6 +30,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
    PrivacyReportMetadata
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
@@ -170,6 +171,54 @@ def run_attacks(attack_input: AttackInputData,
      privacy_report_metadata=privacy_report_metadata)
def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
unused_report_metadata: PrivacyReportMetadata = None,
balance_attacker_training: bool = True) -> AttackResults:
"""Runs membership inference attacks on a seq2seq model.
Args:
attack_input: input data for running an attack
unused_report_metadata: the metadata of the model under attack.
balance_attacker_training: Whether the training and test sets for the
membership inference attacker should have a balanced (roughly equal)
number of samples from the training and test sets used to develop the
model under attack.
Returns:
the attack result.
"""
attack_input.validate()
# The attacker uses the average rank (a single number) of a seq2seq dataset
# record to determine membership. So only Logistic Regression is supported,
# as it makes the most sense for single-number features.
attacker = models.LogisticRegressionAttacker()
prepared_attacker_data = models.create_seq2seq_attacker_data(
attack_input, balance=balance_attacker_training)
attacker.train_model(prepared_attacker_data.features_train,
prepared_attacker_data.is_training_labels_train)
# Run the attacker on (permuted) test examples.
predictions_test = attacker.predict(prepared_attacker_data.features_test)
# Generate ROC curves with predictions.
fpr, tpr, thresholds = metrics.roc_curve(
prepared_attacker_data.is_training_labels_test, predictions_test)
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
attack_results = [
SingleAttackResult(
slice_spec=SingleSliceSpec(),
attack_type=AttackType.LOGISTIC_REGRESSION,
roc_curve=roc_curve)
]
return AttackResults(single_attack_results=attack_results)
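Editor's example (not part of the diff): a sketch of calling the new entry point end to end. It assumes attack_input is a populated Seq2SeqAttackInputData with enough batches on both sides for the attacker's internal stratified train/test split; attack_type and roc_curve.get_auc are the result accessors used elsewhere in this PR, while get_attacker_advantage is an assumed existing accessor on SingleAttackResult.

from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia

# attack_input: a populated Seq2SeqAttackInputData (see the earlier sketch),
# with enough train/test batches for sklearn's stratified split to succeed.
results = mia.run_seq2seq_attack(attack_input, balance_attacker_training=True)

# The result holds a single LOGISTIC_REGRESSION attack; its ROC curve gives
# the usual membership inference metrics.
seq2seq_result = results.single_attack_results[0]
print(seq2seq_result.attack_type)               # AttackType.LOGISTIC_REGRESSION
print(seq2seq_result.roc_curve.get_auc())       # AUC of the membership attacker
print(seq2seq_result.get_attacker_advantage())  # max |tpr - fpr| (assumed accessor)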
def _compute_missing_privacy_report_metadata(
    metadata: PrivacyReportMetadata,
    attack_input: AttackInputData) -> PrivacyReportMetadata:


@@ -19,6 +19,7 @@ import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
@@ -34,6 +35,68 @@ def get_test_input(n_train, n_test):
      labels_test=np.array([i % 5 for i in range(n_test)]))
def get_seq2seq_test_input(n_train,
n_test,
max_seq_in_batch,
max_tokens_in_sequence,
vocab_size,
seed=None):
"""Returns example inputs for attacks on seq2seq models."""
if seed is not None:
np.random.seed(seed=seed)
logits_train, labels_train = [], []
for _ in range(n_train):
num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1
batch_logits, batch_labels = _get_batch_logits_and_labels(
num_sequences, max_tokens_in_sequence, vocab_size)
logits_train.append(batch_logits)
labels_train.append(batch_labels)
logits_test, labels_test = [], []
for _ in range(n_test):
num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1
batch_logits, batch_labels = _get_batch_logits_and_labels(
num_sequences, max_tokens_in_sequence, vocab_size)
logits_test.append(batch_logits)
labels_test.append(batch_labels)
return Seq2SeqAttackInputData(
logits_train=iter(logits_train),
logits_test=iter(logits_test),
labels_train=iter(labels_train),
labels_test=iter(labels_test),
vocab_size=vocab_size,
train_size=n_train,
test_size=n_test)
def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
vocab_size):
num_tokens_in_sequence = np.random.choice(max_tokens_in_sequence,
num_sequences) + 1
batch_logits, batch_labels = [], []
for num_tokens in num_tokens_in_sequence:
logits, labels = _get_sequence_logits_and_labels(num_tokens, vocab_size)
batch_logits.append(logits)
batch_labels.append(labels)
return np.array(
batch_logits, dtype=object), np.array(
batch_labels, dtype=object)
def _get_sequence_logits_and_labels(num_tokens, vocab_size):
sequence_logits = []
for _ in range(num_tokens):
token_logits = np.random.random(vocab_size)
token_logits /= token_logits.sum()
sequence_logits.append(token_logits)
sequence_labels = np.random.choice(vocab_size, num_tokens)
return np.array(
sequence_logits, dtype=np.float32), np.array(
sequence_labels, dtype=np.float32)
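Editor's example (not part of the diff): a quick sketch of the shape conventions produced by the synthetic helper above, assuming it is run inside this test module where get_seq2seq_test_input is defined.

# Each element yielded by logits_train is one batch: an object array with one
# entry per sequence, where each sequence is a (num_tokens, vocab_size) array.
test_input = get_seq2seq_test_input(
    n_train=2, n_test=2, max_seq_in_batch=3, max_tokens_in_sequence=4,
    vocab_size=5, seed=0)

first_batch = next(test_input.logits_train)  # Consumes the iterator, so build a
                                             # fresh input before running an attack.
print(first_batch.shape[0])      # number of sequences in the batch (1 to 3)
print(first_batch[0].shape[-1])  # vocab_size == 5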
class RunAttacksTest(absltest.TestCase):
  def test_run_attacks_size(self):
@@ -97,6 +160,42 @@ class RunAttacksTest(absltest.TestCase):
    # If accuracy is already present, simply return it.
    self.assertIsNone(mia._get_accuracy(None, labels))
def test_run_seq2seq_attack_size(self):
result = mia.run_seq2seq_attack(
get_seq2seq_test_input(
n_train=10,
n_test=5,
max_seq_in_batch=3,
max_tokens_in_sequence=5,
vocab_size=2))
self.assertLen(result.single_attack_results, 1)
def test_run_seq2seq_attack_trained_sets_attack_type(self):
result = mia.run_seq2seq_attack(
get_seq2seq_test_input(
n_train=10,
n_test=5,
max_seq_in_batch=3,
max_tokens_in_sequence=5,
vocab_size=2))
seq2seq_result = list(result.single_attack_results)[0]
self.assertEqual(seq2seq_result.attack_type, AttackType.LOGISTIC_REGRESSION)
def test_run_seq2seq_attack_calculates_correct_auc(self):
result = mia.run_seq2seq_attack(
get_seq2seq_test_input(
n_train=20,
n_test=10,
max_seq_in_batch=3,
max_tokens_in_sequence=5,
vocab_size=3,
seed=12345),
balance_attacker_training=False)
seq2seq_result = list(result.single_attack_results)[0]
np.testing.assert_almost_equal(
seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
if __name__ == '__main__':
  absltest.main()


@@ -15,8 +15,11 @@
# Lint as: python3
"""Trained models for membership inference attacks."""
from typing import Iterator, List
from dataclasses import dataclass
import numpy as np
from scipy.stats import rankdata
from sklearn import ensemble
from sklearn import linear_model
from sklearn import model_selection
@@ -24,6 +27,7 @@ from sklearn import neighbors
from sklearn import neural_network
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
@dataclass
@@ -110,6 +114,98 @@ def _column_stack(logits, loss):
  return np.column_stack((logits, loss))
def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
test_fraction: float = 0.25,
balance: bool = True) -> AttackerData:
"""Prepare Seq2SeqAttackInputData to train ML attackers.
  Uses logits and labels to generate ranks and performs a random train-test
  split.
Args:
attack_input_data: Original Seq2SeqAttackInputData
test_fraction: Fraction of the dataset to include in the test split.
balance: Whether the training and test sets for the membership inference
attacker should have a balanced (roughly equal) number of samples from the
training and test sets used to develop the model under attack.
Returns:
AttackerData.
"""
attack_input_train = _get_average_ranks(attack_input_data.logits_train,
attack_input_data.labels_train)
attack_input_test = _get_average_ranks(attack_input_data.logits_test,
attack_input_data.labels_test)
if balance:
min_size = min(len(attack_input_train), len(attack_input_test))
attack_input_train = _sample_multidimensional_array(attack_input_train,
min_size)
attack_input_test = _sample_multidimensional_array(attack_input_test,
min_size)
features_all = np.concatenate((attack_input_train, attack_input_test))
# Reshape for classifying one-dimensional features
features_all = features_all.reshape(-1, 1)
labels_all = np.concatenate(
((np.zeros(len(attack_input_train))), (np.ones(len(attack_input_test)))))
# Perform a train-test split
features_train, features_test, \
is_training_labels_train, is_training_labels_test = \
model_selection.train_test_split(
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
return AttackerData(features_train, is_training_labels_train, features_test,
is_training_labels_test)
def _get_average_ranks(logits: Iterator[np.ndarray],
labels: Iterator[np.ndarray]) -> np.ndarray:
"""Returns the average rank of tokens in a batch of sequences.
Args:
logits: Logits returned by a seq2seq model, dim = (num_batches,
num_sequences, num_tokens, vocab_size).
labels: Target labels for the seq2seq model, dim = (num_batches,
num_sequences, num_tokens, 1).
Returns:
An array of average ranks, dim = (num_batches, 1).
Each average rank is calculated over ranks of tokens in sequences of a
particular batch.
"""
ranks = []
for batch_logits, batch_labels in zip(logits, labels):
batch_ranks = []
for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
batch_ranks += _get_ranks_for_sequence(sequence_logits, sequence_labels)
ranks.append(np.mean(batch_ranks))
return np.array(ranks)
def _get_ranks_for_sequence(logits: np.ndarray,
labels: np.ndarray) -> List[float]:
"""Returns ranks for a sequence.
Args:
logits: Logits of a single sequence, dim = (num_tokens, vocab_size).
labels: Target labels of a single sequence, dim = (num_tokens, 1).
Returns:
    A list of ranks for the tokens in the sequence, one rank per token.
"""
sequence_ranks = []
for logit, label in zip(logits, labels.astype(int)):
rank = rankdata(-logit, method='min')[label] - 1.0
sequence_ranks.append(rank)
return sequence_ranks
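Editor's example (not part of the diff): a worked instance of the per-token rank computation above, with hypothetical numbers. A token whose true label has the second-highest logit gets rank 1.0, a token whose true label has the highest logit gets 0.0, and the per-batch mean of these ranks is the single feature fed to the attacker.

import numpy as np
from scipy.stats import rankdata

# Two tokens, vocab_size 3. For the first token the true label (index 2) has
# the second-highest logit; for the second token the true label (index 0) has
# the highest logit.
logits = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.2, 0.3]], dtype=np.float32)
labels = np.array([2, 0], dtype=np.float32)

ranks = [rankdata(-token_logits, method='min')[int(label)] - 1.0
         for token_logits, label in zip(logits, labels)]
print(ranks)           # [1.0, 0.0]
print(np.mean(ranks))  # 0.5 -- the average rank used as the membership feature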
class TrainedAttacker:
  """Base class for training attack models."""
  model = None


@@ -19,6 +19,7 @@ import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack import models
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
class TrainedAttackerTest(absltest.TestCase):
@@ -55,6 +56,66 @@ class TrainedAttackerTest(absltest.TestCase):
      expected = feature[:2] not in attack_input.logits_train
      self.assertEqual(attacker_data.is_training_labels_train[i], expected)
def test_create_seq2seq_attacker_data_logits_and_labels(self):
attack_input = Seq2SeqAttackInputData(
logits_train=iter([
np.array([
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
],
dtype=object),
np.array(
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
dtype=object),
np.array([
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
],
dtype=object)
]),
logits_test=iter([
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
],
dtype=object),
np.array([
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
],
dtype=object)
]),
labels_train=iter([
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
],
dtype=object),
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
np.array([
np.array([0, 1], dtype=np.float32),
np.array([1, 2], dtype=np.float32)
],
dtype=object)
]),
labels_test=iter([
np.array([np.array([2, 1], dtype=np.float32)]),
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
],
dtype=object)
]),
vocab_size=3,
train_size=3,
test_size=2)
attacker_data = models.create_seq2seq_attacker_data(
attack_input, 0.25, balance=False)
self.assertLen(attacker_data.features_train, 3)
self.assertLen(attacker_data.features_test, 2)
for _, feature in enumerate(attacker_data.features_train):
self.assertLen(feature, 1) # each feature has one average rank
  def test_balanced_create_attacker_data_loss_and_logits(self):
    attack_input = AttackInputData(
        logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
@@ -70,6 +131,71 @@ class TrainedAttackerTest(absltest.TestCase):
      expected = feature[:2] not in attack_input.logits_train
      self.assertEqual(attacker_data.is_training_labels_train[i], expected)
def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
attack_input = Seq2SeqAttackInputData(
logits_train=iter([
np.array([
np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
],
dtype=object),
np.array(
[np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
dtype=object),
np.array([
np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
],
dtype=object)
]),
logits_test=iter([
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
],
dtype=object),
np.array([
np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
],
dtype=object),
np.array([
np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
],
dtype=object)
]),
labels_train=iter([
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
],
dtype=object),
np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
np.array([
np.array([0, 1], dtype=np.float32),
np.array([1, 2], dtype=np.float32)
],
dtype=object)
]),
labels_test=iter([
np.array([np.array([2, 1], dtype=np.float32)]),
np.array([
np.array([2, 0], dtype=np.float32),
np.array([1], dtype=np.float32)
],
dtype=object),
np.array([np.array([2, 1], dtype=np.float32)])
]),
vocab_size=3,
train_size=3,
test_size=3)
attacker_data = models.create_seq2seq_attacker_data(
attack_input, 0.33, balance=True)
self.assertLen(attacker_data.features_train, 4)
self.assertLen(attacker_data.features_test, 2)
for _, feature in enumerate(attacker_data.features_train):
self.assertLen(feature, 1) # each feature has one average rank
if __name__ == '__main__':
  absltest.main()