Add Seq2SeqAttackInputData data structure
This commit is contained in:
parent
f0daaf085f
commit
9f07f2a871
1 changed files with 79 additions and 1 deletions
|
@ -18,7 +18,7 @@ import enum
|
|||
import glob
|
||||
import os
|
||||
import pickle
|
||||
from typing import Any, Iterable, Union
|
||||
from typing import Any, Iterable, Union, Iterator
|
||||
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
|
@ -374,6 +374,84 @@ def _append_array_shape(arr: np.array, arr_name: str, result):
|
|||
result.append(' %s with shape: %s,' % (arr_name, arr.shape))
|
||||
|
||||
|
||||
def _is_generator(gen, gen_name):
|
||||
"""Checks whether gen is a generator."""
|
||||
if gen is not None and not isinstance(gen, Iterator):
|
||||
raise ValueError('%s should be a generator.' % gen_name)
|
||||
|
||||
|
||||
@dataclass
class Seq2SeqAttackInputData:
  """Input data for running an attack on seq2seq models.

  This includes only the data, and not configuration.

  Logits and labels are provided as iterators, so their sizes cannot be read
  from the data itself; vocab_size, train_size and test_size must be supplied
  explicitly by the caller.
  """
  logits_train: Iterator[np.ndarray] = None
  logits_test: Iterator[np.ndarray] = None

  # Contains ground-truth token indices for the target sequences.
  labels_train: Iterator[np.ndarray] = None
  labels_test: Iterator[np.ndarray] = None

  # Denotes size of the target sequence vocabulary.
  vocab_size: int = None

  # Train, test size = number of batches in training, test set.
  # These values need to be supplied by the user as logits, labels
  # are lazy loaded for seq2seq models.
  train_size: int = 0
  test_size: int = 0

  def validate(self):
    """Validates the inputs.

    Raises:
      ValueError: If only one of the train/test logits (or labels) is set, if
        any of logits/labels is unset, if vocab_size, train_size or test_size
        is None or not an integer, or if any of the logits/labels is not an
        iterator.
    """
    if (self.logits_train is None) != (self.logits_test is None):
      raise ValueError(
          'logits_train and logits_test should both be either set or unset')

    if (self.labels_train is None) != (self.labels_test is None):
      raise ValueError(
          'labels_train and labels_test should both be either set or unset')

    if self.logits_train is None or self.labels_train is None:
      raise ValueError(
          'Labels, logits of training, test sets should all be set')

    size_fields = ('vocab_size', 'train_size', 'test_size')

    if any(getattr(self, name) is None for name in size_fields):
      raise ValueError(
          'vocab_size, train_size, test_size should all be set')

    # The previous check was `... and not int`, which is always False because
    # the `int` class is truthy, so wrong types were never rejected. NumPy
    # integer scalars are accepted as well so sizes computed with NumPy keep
    # working.
    for name in size_fields:
      if not isinstance(getattr(self, name), (int, np.integer)):
        raise ValueError('%s should be of integer type' % name)

    _is_generator(self.logits_train, 'logits_train')
    _is_generator(self.logits_test, 'logits_test')
    _is_generator(self.labels_train, 'labels_train')
    _is_generator(self.labels_test, 'labels_test')

  def __str__(self):
    """Returns a description of the expected logits and labels shapes.

    The shapes are built from train_size, test_size and vocab_size only;
    'num_sequences' and 'num_tokens' are literal placeholders because the
    iterators are not consumed here.
    """
    # NOTE(review): the header says 'AttackInputData(' rather than this class's
    # name. Kept unchanged since callers may match on it -- confirm.
    result = ['AttackInputData(']

    if self.vocab_size is not None and self.train_size is not None:
      result.append(
          'logits_train with shape (%d, num_sequences, num_tokens, %d)' %
          (self.train_size, self.vocab_size))
      result.append(
          'labels_train with shape (%d, num_sequences, num_tokens, 1)' %
          self.train_size)

    if self.vocab_size is not None and self.test_size is not None:
      result.append(
          'logits_test with shape (%d, num_sequences, num_tokens, %d)' %
          (self.test_size, self.vocab_size))
      result.append(
          'labels_test with shape (%d, num_sequences, num_tokens, 1)' %
          self.test_size)

    result.append(')')
    return '\n'.join(result)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RocCurve:
|
||||
"""Represents ROC curve of a membership inference classifier."""
|
||||
|
|
Loading…
Reference in a new issue