Add tests for Seq2SeqAttackInputData
This commit is contained in:
parent
d1c1746cdb
commit
4db54d9485
1 changed files with 65 additions and 1 deletions
|
@ -22,6 +22,7 @@ import numpy as np
|
|||
import pandas as pd
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import _log_value
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
|
||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||
|
@ -142,6 +143,70 @@ class AttackInputDataTest(absltest.TestCase):
|
|||
probs_test=np.array([])).validate)
|
||||
|
||||
|
||||
class Seq2SeqAttackInputDataTest(absltest.TestCase):
|
||||
def test_validator(self):
|
||||
valid_logits_train = iter([np.array([]), np.array([])])
|
||||
valid_logits_test = iter([np.array([]), np.array([])])
|
||||
valid_labels_train = iter([np.array([]), np.array([])])
|
||||
valid_labels_test = iter([np.array([]), np.array([])])
|
||||
|
||||
invalid_logits_train = []
|
||||
invalid_logits_test = []
|
||||
invalid_labels_train = []
|
||||
invalid_labels_test = []
|
||||
|
||||
self.assertRaises(ValueError,
|
||||
Seq2SeqAttackInputData(logits_train=valid_logits_train).validate)
|
||||
self.assertRaises(ValueError,
|
||||
Seq2SeqAttackInputData(labels_train=valid_labels_train).validate)
|
||||
self.assertRaises(ValueError,
|
||||
Seq2SeqAttackInputData(logits_test=valid_logits_test).validate)
|
||||
self.assertRaises(ValueError,
|
||||
Seq2SeqAttackInputData(labels_test=valid_labels_test).validate)
|
||||
self.assertRaises(ValueError,
|
||||
Seq2SeqAttackInputData(vocab_size=0).validate)
|
||||
self.assertRaises(ValueError,
|
||||
Seq2SeqAttackInputData(train_size=0).validate)
|
||||
self.assertRaises(ValueError,
|
||||
Seq2SeqAttackInputData(test_size=0).validate)
|
||||
self.assertRaises(ValueError, Seq2SeqAttackInputData().validate)
|
||||
|
||||
# Tests that both logits and labels must be set.
|
||||
self.assertRaises(ValueError,
|
||||
Seq2SeqAttackInputData(
|
||||
logits_train=valid_logits_train,
|
||||
logits_test=valid_logits_test,
|
||||
vocab_size=0,
|
||||
train_size=0,
|
||||
test_size=0).validate)
|
||||
self.assertRaises(ValueError,
|
||||
Seq2SeqAttackInputData(
|
||||
labels_train=valid_labels_train,
|
||||
labels_test=valid_labels_test,
|
||||
vocab_size=0,
|
||||
train_size=0,
|
||||
test_size=0).validate)
|
||||
|
||||
# Tests that vocab, train, test sizes must all be set.
|
||||
self.assertRaises(ValueError,
|
||||
Seq2SeqAttackInputData(
|
||||
logits_train=valid_logits_train,
|
||||
logits_test=valid_logits_test,
|
||||
labels_train=valid_labels_train,
|
||||
labels_test=valid_labels_test).validate)
|
||||
|
||||
self.assertRaises(ValueError,
|
||||
Seq2SeqAttackInputData(
|
||||
logits_train=invalid_logits_train,
|
||||
logits_test=invalid_logits_test,
|
||||
labels_train=invalid_labels_train,
|
||||
labels_test=invalid_labels_test,
|
||||
vocab_size=0,
|
||||
train_size=0,
|
||||
test_size=0
|
||||
).validate)
|
||||
|
||||
|
||||
class RocCurveTest(absltest.TestCase):
|
||||
|
||||
def test_auc_random_classifier(self):
|
||||
|
@ -265,7 +330,6 @@ class AttackResultsCollectionTest(absltest.TestCase):
|
|||
|
||||
|
||||
class AttackResultsTest(absltest.TestCase):
|
||||
|
||||
perfect_classifier_result: SingleAttackResult
|
||||
random_classifier_result: SingleAttackResult
|
||||
|
||||
|
|
Loading…
Reference in a new issue