diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py index 2f7c205..fd549d2 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py @@ -18,7 +18,7 @@ import enum import glob import os import pickle -from typing import Any, Iterable, Union +from typing import Any, Iterable, Union, Iterator from dataclasses import dataclass import numpy as np @@ -374,6 +374,84 @@ def _append_array_shape(arr: np.array, arr_name: str, result): result.append(' %s with shape: %s,' % (arr_name, arr.shape)) +def _is_generator(gen, gen_name): + """Checks whether gen is a generator.""" + if gen is not None and not isinstance(gen, Iterator): + raise ValueError('%s should be a generator.' % gen_name) + + +@dataclass +class Seq2SeqAttackInputData: + """Input data for running an attack on seq2seq models. + + This includes only the data, and not configuration. + """ + logits_train: Iterator[np.ndarray] = None + logits_test: Iterator[np.ndarray] = None + + # Contains ground-truth token indices for the target sequences. + labels_train: Iterator[np.ndarray] = None + labels_test: Iterator[np.ndarray] = None + + # Denotes size of the target sequence vocabulary. + vocab_size: int = None + + # Train, test size = number of batches in training, test set. + # These values need to be supplied by the user as logits, labels + # are lazy loaded for seq2seq models. + train_size: int = 0 + test_size: int = 0 + + def validate(self): + """Validates the inputs.""" + + if (self.logits_train is None) != (self.logits_test is None): + raise ValueError( + 'logits_train and logits_test should both be either set or unset') + + if (self.labels_train is None) != (self.labels_test is None): + raise ValueError( + 'labels_train and labels_test should both be either set or unset') + + if self.logits_train is None or self.labels_train is None: + raise ValueError( + 'Labels, logits of training, test sets should all be set') + + if (self.vocab_size is None or self.train_size is None + or self.test_size is None): + raise ValueError( + 'vocab_size, train_size, test_size should all be set') + + if self.vocab_size is not None and not int: + raise ValueError('vocab_size should be of integer type') + + if self.train_size is not None and not int: + raise ValueError('train_size should be of integer type') + + if self.test_size is not None and not int: + raise ValueError('test_size should be of integer type') + + _is_generator(self.logits_train, 'logits_train') + _is_generator(self.logits_test, 'logits_test') + _is_generator(self.labels_train, 'labels_train') + _is_generator(self.labels_test, 'labels_test') + + def __str__(self): + """Return the shapes of variables that are not None.""" + result = ['AttackInputData('] + + if self.vocab_size is not None and self.train_size is not None: + result.append('logits_train with shape (%d, num_sequences, num_tokens, %d)' % (self.train_size, self.vocab_size)) + result.append('labels_train with shape (%d, num_sequences, num_tokens, 1)' % self.train_size) + + if self.vocab_size is not None and self.test_size is not None: + result.append('logits_test with shape (%d, num_sequences, num_tokens, %d)' % (self.test_size, self.vocab_size)) + result.append('labels_test with shape (%d, num_sequences, num_tokens, 1)' % self.test_size) + + result.append(')') + return '\n'.join(result) + + @dataclass class RocCurve: """Represents ROC curve of a membership inference classifier."""