Add tests for Seq2SeqAttackInputData
This commit is contained in:
parent
d1c1746cdb
commit
4db54d9485
1 changed files with 65 additions and 1 deletions
|
@ -22,6 +22,7 @@ import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import _log_value
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import _log_value
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
|
||||||
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
|
||||||
|
@ -142,6 +143,70 @@ class AttackInputDataTest(absltest.TestCase):
|
||||||
probs_test=np.array([])).validate)
|
probs_test=np.array([])).validate)
|
||||||
|
|
||||||
|
|
||||||
|
class Seq2SeqAttackInputDataTest(absltest.TestCase):
|
||||||
|
def test_validator(self):
|
||||||
|
valid_logits_train = iter([np.array([]), np.array([])])
|
||||||
|
valid_logits_test = iter([np.array([]), np.array([])])
|
||||||
|
valid_labels_train = iter([np.array([]), np.array([])])
|
||||||
|
valid_labels_test = iter([np.array([]), np.array([])])
|
||||||
|
|
||||||
|
invalid_logits_train = []
|
||||||
|
invalid_logits_test = []
|
||||||
|
invalid_labels_train = []
|
||||||
|
invalid_labels_test = []
|
||||||
|
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
Seq2SeqAttackInputData(logits_train=valid_logits_train).validate)
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
Seq2SeqAttackInputData(labels_train=valid_labels_train).validate)
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
Seq2SeqAttackInputData(logits_test=valid_logits_test).validate)
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
Seq2SeqAttackInputData(labels_test=valid_labels_test).validate)
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
Seq2SeqAttackInputData(vocab_size=0).validate)
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
Seq2SeqAttackInputData(train_size=0).validate)
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
Seq2SeqAttackInputData(test_size=0).validate)
|
||||||
|
self.assertRaises(ValueError, Seq2SeqAttackInputData().validate)
|
||||||
|
|
||||||
|
# Tests that both logits and labels must be set.
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
Seq2SeqAttackInputData(
|
||||||
|
logits_train=valid_logits_train,
|
||||||
|
logits_test=valid_logits_test,
|
||||||
|
vocab_size=0,
|
||||||
|
train_size=0,
|
||||||
|
test_size=0).validate)
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
Seq2SeqAttackInputData(
|
||||||
|
labels_train=valid_labels_train,
|
||||||
|
labels_test=valid_labels_test,
|
||||||
|
vocab_size=0,
|
||||||
|
train_size=0,
|
||||||
|
test_size=0).validate)
|
||||||
|
|
||||||
|
# Tests that vocab, train, test sizes must all be set.
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
Seq2SeqAttackInputData(
|
||||||
|
logits_train=valid_logits_train,
|
||||||
|
logits_test=valid_logits_test,
|
||||||
|
labels_train=valid_labels_train,
|
||||||
|
labels_test=valid_labels_test).validate)
|
||||||
|
|
||||||
|
self.assertRaises(ValueError,
|
||||||
|
Seq2SeqAttackInputData(
|
||||||
|
logits_train=invalid_logits_train,
|
||||||
|
logits_test=invalid_logits_test,
|
||||||
|
labels_train=invalid_labels_train,
|
||||||
|
labels_test=invalid_labels_test,
|
||||||
|
vocab_size=0,
|
||||||
|
train_size=0,
|
||||||
|
test_size=0
|
||||||
|
).validate)
|
||||||
|
|
||||||
|
|
||||||
class RocCurveTest(absltest.TestCase):
|
class RocCurveTest(absltest.TestCase):
|
||||||
|
|
||||||
def test_auc_random_classifier(self):
|
def test_auc_random_classifier(self):
|
||||||
|
@ -265,7 +330,6 @@ class AttackResultsCollectionTest(absltest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
class AttackResultsTest(absltest.TestCase):
|
class AttackResultsTest(absltest.TestCase):
|
||||||
|
|
||||||
perfect_classifier_result: SingleAttackResult
|
perfect_classifier_result: SingleAttackResult
|
||||||
random_classifier_result: SingleAttackResult
|
random_classifier_result: SingleAttackResult
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue