Add tests for Seq2SeqAttackInputData

This commit is contained in:
amad-person 2020-11-06 16:46:57 +08:00
parent d1c1746cdb
commit 4db54d9485

View file

@ -22,6 +22,7 @@ import numpy as np
import pandas as pd
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import _log_value
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
@ -142,6 +143,70 @@ class AttackInputDataTest(absltest.TestCase):
probs_test=np.array([])).validate)
class Seq2SeqAttackInputDataTest(absltest.TestCase):
def test_validator(self):
valid_logits_train = iter([np.array([]), np.array([])])
valid_logits_test = iter([np.array([]), np.array([])])
valid_labels_train = iter([np.array([]), np.array([])])
valid_labels_test = iter([np.array([]), np.array([])])
invalid_logits_train = []
invalid_logits_test = []
invalid_labels_train = []
invalid_labels_test = []
self.assertRaises(ValueError,
Seq2SeqAttackInputData(logits_train=valid_logits_train).validate)
self.assertRaises(ValueError,
Seq2SeqAttackInputData(labels_train=valid_labels_train).validate)
self.assertRaises(ValueError,
Seq2SeqAttackInputData(logits_test=valid_logits_test).validate)
self.assertRaises(ValueError,
Seq2SeqAttackInputData(labels_test=valid_labels_test).validate)
self.assertRaises(ValueError,
Seq2SeqAttackInputData(vocab_size=0).validate)
self.assertRaises(ValueError,
Seq2SeqAttackInputData(train_size=0).validate)
self.assertRaises(ValueError,
Seq2SeqAttackInputData(test_size=0).validate)
self.assertRaises(ValueError, Seq2SeqAttackInputData().validate)
# Tests that both logits and labels must be set.
self.assertRaises(ValueError,
Seq2SeqAttackInputData(
logits_train=valid_logits_train,
logits_test=valid_logits_test,
vocab_size=0,
train_size=0,
test_size=0).validate)
self.assertRaises(ValueError,
Seq2SeqAttackInputData(
labels_train=valid_labels_train,
labels_test=valid_labels_test,
vocab_size=0,
train_size=0,
test_size=0).validate)
# Tests that vocab, train, test sizes must all be set.
self.assertRaises(ValueError,
Seq2SeqAttackInputData(
logits_train=valid_logits_train,
logits_test=valid_logits_test,
labels_train=valid_labels_train,
labels_test=valid_labels_test).validate)
self.assertRaises(ValueError,
Seq2SeqAttackInputData(
logits_train=invalid_logits_train,
logits_test=invalid_logits_test,
labels_train=invalid_labels_train,
labels_test=invalid_labels_test,
vocab_size=0,
train_size=0,
test_size=0
).validate)
class RocCurveTest(absltest.TestCase):
def test_auc_random_classifier(self):
@ -265,7 +330,6 @@ class AttackResultsCollectionTest(absltest.TestCase):
class AttackResultsTest(absltest.TestCase):
perfect_classifier_result: SingleAttackResult
random_classifier_result: SingleAttackResult