forked from 626_privacy/tensorflow_privacy
Add Seq2SeqAttackInputData data structure
This commit is contained in:
parent
f0daaf085f
commit
9f07f2a871
1 changed files with 79 additions and 1 deletions
|
@ -18,7 +18,7 @@ import enum
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from typing import Any, Iterable, Union
|
from typing import Any, Iterable, Union, Iterator
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -374,6 +374,84 @@ def _append_array_shape(arr: np.array, arr_name: str, result):
|
||||||
result.append(' %s with shape: %s,' % (arr_name, arr.shape))
|
result.append(' %s with shape: %s,' % (arr_name, arr.shape))
|
||||||
|
|
||||||
|
|
||||||
|
def _is_generator(gen, gen_name):
|
||||||
|
"""Checks whether gen is a generator."""
|
||||||
|
if gen is not None and not isinstance(gen, Iterator):
|
||||||
|
raise ValueError('%s should be a generator.' % gen_name)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Seq2SeqAttackInputData:
|
||||||
|
"""Input data for running an attack on seq2seq models.
|
||||||
|
|
||||||
|
This includes only the data, and not configuration.
|
||||||
|
"""
|
||||||
|
logits_train: Iterator[np.ndarray] = None
|
||||||
|
logits_test: Iterator[np.ndarray] = None
|
||||||
|
|
||||||
|
# Contains ground-truth token indices for the target sequences.
|
||||||
|
labels_train: Iterator[np.ndarray] = None
|
||||||
|
labels_test: Iterator[np.ndarray] = None
|
||||||
|
|
||||||
|
# Denotes size of the target sequence vocabulary.
|
||||||
|
vocab_size: int = None
|
||||||
|
|
||||||
|
# Train, test size = number of batches in training, test set.
|
||||||
|
# These values need to be supplied by the user as logits, labels
|
||||||
|
# are lazy loaded for seq2seq models.
|
||||||
|
train_size: int = 0
|
||||||
|
test_size: int = 0
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
"""Validates the inputs."""
|
||||||
|
|
||||||
|
if (self.logits_train is None) != (self.logits_test is None):
|
||||||
|
raise ValueError(
|
||||||
|
'logits_train and logits_test should both be either set or unset')
|
||||||
|
|
||||||
|
if (self.labels_train is None) != (self.labels_test is None):
|
||||||
|
raise ValueError(
|
||||||
|
'labels_train and labels_test should both be either set or unset')
|
||||||
|
|
||||||
|
if self.logits_train is None or self.labels_train is None:
|
||||||
|
raise ValueError(
|
||||||
|
'Labels, logits of training, test sets should all be set')
|
||||||
|
|
||||||
|
if (self.vocab_size is None or self.train_size is None
|
||||||
|
or self.test_size is None):
|
||||||
|
raise ValueError(
|
||||||
|
'vocab_size, train_size, test_size should all be set')
|
||||||
|
|
||||||
|
if self.vocab_size is not None and not int:
|
||||||
|
raise ValueError('vocab_size should be of integer type')
|
||||||
|
|
||||||
|
if self.train_size is not None and not int:
|
||||||
|
raise ValueError('train_size should be of integer type')
|
||||||
|
|
||||||
|
if self.test_size is not None and not int:
|
||||||
|
raise ValueError('test_size should be of integer type')
|
||||||
|
|
||||||
|
_is_generator(self.logits_train, 'logits_train')
|
||||||
|
_is_generator(self.logits_test, 'logits_test')
|
||||||
|
_is_generator(self.labels_train, 'labels_train')
|
||||||
|
_is_generator(self.labels_test, 'labels_test')
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
"""Return the shapes of variables that are not None."""
|
||||||
|
result = ['AttackInputData(']
|
||||||
|
|
||||||
|
if self.vocab_size is not None and self.train_size is not None:
|
||||||
|
result.append('logits_train with shape (%d, num_sequences, num_tokens, %d)' % (self.train_size, self.vocab_size))
|
||||||
|
result.append('labels_train with shape (%d, num_sequences, num_tokens, 1)' % self.train_size)
|
||||||
|
|
||||||
|
if self.vocab_size is not None and self.test_size is not None:
|
||||||
|
result.append('logits_test with shape (%d, num_sequences, num_tokens, %d)' % (self.test_size, self.vocab_size))
|
||||||
|
result.append('labels_test with shape (%d, num_sequences, num_tokens, 1)' % self.test_size)
|
||||||
|
|
||||||
|
result.append(')')
|
||||||
|
return '\n'.join(result)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RocCurve:
|
class RocCurve:
|
||||||
"""Represents ROC curve of a membership inference classifier."""
|
"""Represents ROC curve of a membership inference classifier."""
|
||||||
|
|
Loading…
Reference in a new issue