Add Seq2SeqAttackInputData data structure

This commit is contained in:
amad-person 2020-11-06 16:42:31 +08:00
parent f0daaf085f
commit 9f07f2a871

View file

@ -18,7 +18,7 @@ import enum
import glob
import os
import pickle
from typing import Any, Iterable, Union
from typing import Any, Iterable, Union, Iterator
from dataclasses import dataclass
import numpy as np
@ -374,6 +374,84 @@ def _append_array_shape(arr: np.array, arr_name: str, result):
result.append(' %s with shape: %s,' % (arr_name, arr.shape))
def _is_generator(gen, gen_name):
"""Checks whether gen is a generator."""
if gen is not None and not isinstance(gen, Iterator):
raise ValueError('%s should be a generator.' % gen_name)
@dataclass
class Seq2SeqAttackInputData:
"""Input data for running an attack on seq2seq models.
This includes only the data, and not configuration.
"""
logits_train: Iterator[np.ndarray] = None
logits_test: Iterator[np.ndarray] = None
# Contains ground-truth token indices for the target sequences.
labels_train: Iterator[np.ndarray] = None
labels_test: Iterator[np.ndarray] = None
# Denotes size of the target sequence vocabulary.
vocab_size: int = None
# Train, test size = number of batches in training, test set.
# These values need to be supplied by the user as logits, labels
# are lazy loaded for seq2seq models.
train_size: int = 0
test_size: int = 0
def validate(self):
"""Validates the inputs."""
if (self.logits_train is None) != (self.logits_test is None):
raise ValueError(
'logits_train and logits_test should both be either set or unset')
if (self.labels_train is None) != (self.labels_test is None):
raise ValueError(
'labels_train and labels_test should both be either set or unset')
if self.logits_train is None or self.labels_train is None:
raise ValueError(
'Labels, logits of training, test sets should all be set')
if (self.vocab_size is None or self.train_size is None
or self.test_size is None):
raise ValueError(
'vocab_size, train_size, test_size should all be set')
if self.vocab_size is not None and not int:
raise ValueError('vocab_size should be of integer type')
if self.train_size is not None and not int:
raise ValueError('train_size should be of integer type')
if self.test_size is not None and not int:
raise ValueError('test_size should be of integer type')
_is_generator(self.logits_train, 'logits_train')
_is_generator(self.logits_test, 'logits_test')
_is_generator(self.labels_train, 'labels_train')
_is_generator(self.labels_test, 'labels_test')
def __str__(self):
"""Return the shapes of variables that are not None."""
result = ['AttackInputData(']
if self.vocab_size is not None and self.train_size is not None:
result.append('logits_train with shape (%d, num_sequences, num_tokens, %d)' % (self.train_size, self.vocab_size))
result.append('labels_train with shape (%d, num_sequences, num_tokens, 1)' % self.train_size)
if self.vocab_size is not None and self.test_size is not None:
result.append('logits_test with shape (%d, num_sequences, num_tokens, %d)' % (self.test_size, self.vocab_size))
result.append('labels_test with shape (%d, num_sequences, num_tokens, 1)' % self.test_size)
result.append(')')
return '\n'.join(result)
@dataclass
class RocCurve:
"""Represents ROC curve of a membership inference classifier."""