Add Seq2SeqAttackInputData data structure

This commit is contained in:
amad-person 2020-11-06 16:42:31 +08:00
parent f0daaf085f
commit 9f07f2a871

View file

@ -18,7 +18,7 @@ import enum
import glob import glob
import os import os
import pickle import pickle
from typing import Any, Iterable, Union from typing import Any, Iterable, Union, Iterator
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
@ -374,6 +374,84 @@ def _append_array_shape(arr: np.array, arr_name: str, result):
result.append(' %s with shape: %s,' % (arr_name, arr.shape)) result.append(' %s with shape: %s,' % (arr_name, arr.shape))
def _is_generator(gen, gen_name):
  """Raises if `gen` is set but is not an iterator.

  Despite the name, the check is against `Iterator`, so any iterator (not
  only generator objects) is accepted. `None` is treated as "unset" and is
  accepted without error.

  Args:
    gen: The object to check; may be None.
    gen_name: Name of the checked field, used only in the error message.

  Raises:
    ValueError: If `gen` is not None and is not an `Iterator`.
  """
  if gen is None:
    # An unset value is not this helper's concern; callers check for that.
    return
  if not isinstance(gen, Iterator):
    raise ValueError('%s should be a generator.' % gen_name)
@dataclass
class Seq2SeqAttackInputData:
  """Input data for running an attack on seq2seq models.

  This includes only the data, and not configuration.
  """

  # Logits for the training and test sets, supplied as iterators.
  logits_train: Iterator[np.ndarray] = None
  logits_test: Iterator[np.ndarray] = None

  # Contains ground-truth token indices for the target sequences.
  labels_train: Iterator[np.ndarray] = None
  labels_test: Iterator[np.ndarray] = None

  # Denotes size of the target sequence vocabulary.
  vocab_size: int = None

  # Train, test size = number of batches in training, test set.
  # These values need to be supplied by the user as logits, labels
  # are lazy loaded for seq2seq models.
  train_size: int = 0
  test_size: int = 0

  def validate(self):
    """Validates the inputs.

    Raises:
      ValueError: If only one of a train/test pair is set, if any of the
        logits/labels are unset, if `vocab_size`, `train_size` or
        `test_size` is None or not an integer, or if any logits/labels
        value is not an iterator.
    """
    if (self.logits_train is None) != (self.logits_test is None):
      raise ValueError(
          'logits_train and logits_test should both be either set or unset')

    if (self.labels_train is None) != (self.labels_test is None):
      raise ValueError(
          'labels_train and labels_test should both be either set or unset')

    if self.logits_train is None or self.labels_train is None:
      raise ValueError(
          'Labels, logits of training, test sets should all be set')

    if (self.vocab_size is None or self.train_size is None or
        self.test_size is None):
      raise ValueError(
          'vocab_size, train_size, test_size should all be set')

    # The sizes are known to be non-None here, so only the type remains to
    # be checked. NumPy integer scalars are accepted alongside Python ints.
    for field_name in ('vocab_size', 'train_size', 'test_size'):
      if not isinstance(getattr(self, field_name), (int, np.integer)):
        raise ValueError('%s should be of integer type' % field_name)

    _is_generator(self.logits_train, 'logits_train')
    _is_generator(self.logits_test, 'logits_test')
    _is_generator(self.labels_train, 'labels_train')
    _is_generator(self.labels_test, 'labels_test')

  def __str__(self):
    """Returns a description of the expected shapes of logits and labels.

    The shapes are derived from `train_size`, `test_size` and `vocab_size`
    only; the iterators themselves are not consumed.
    """
    # NOTE(review): the header reads 'AttackInputData(' rather than the
    # class name; kept unchanged because it is runtime output.
    result = ['AttackInputData(']

    if self.vocab_size is not None and self.train_size is not None:
      result.append(
          'logits_train with shape (%d, num_sequences, num_tokens, %d)' %
          (self.train_size, self.vocab_size))
      result.append(
          'labels_train with shape (%d, num_sequences, num_tokens, 1)' %
          self.train_size)

    if self.vocab_size is not None and self.test_size is not None:
      result.append(
          'logits_test with shape (%d, num_sequences, num_tokens, %d)' %
          (self.test_size, self.vocab_size))
      result.append(
          'labels_test with shape (%d, num_sequences, num_tokens, 1)' %
          self.test_size)

    result.append(')')
    return '\n'.join(result)
@dataclass @dataclass
class RocCurve: class RocCurve:
"""Represents ROC curve of a membership inference classifier.""" """Represents ROC curve of a membership inference classifier."""