diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index d5c9d1d..80f4ca4 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -22,6 +22,7 @@ import numpy as np import pandas as pd from tensorflow_privacy.privacy.membership_inference_attack.data_structures import _log_value from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData +from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResults from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackResultsCollection from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType @@ -142,6 +143,70 @@ class AttackInputDataTest(absltest.TestCase): probs_test=np.array([])).validate) +class Seq2SeqAttackInputDataTest(absltest.TestCase): + def test_validator(self): + valid_logits_train = iter([np.array([]), np.array([])]) + valid_logits_test = iter([np.array([]), np.array([])]) + valid_labels_train = iter([np.array([]), np.array([])]) + valid_labels_test = iter([np.array([]), np.array([])]) + + invalid_logits_train = [] + invalid_logits_test = [] + invalid_labels_train = [] + invalid_labels_test = [] + + self.assertRaises(ValueError, + Seq2SeqAttackInputData(logits_train=valid_logits_train).validate) + self.assertRaises(ValueError, + Seq2SeqAttackInputData(labels_train=valid_labels_train).validate) + self.assertRaises(ValueError, + Seq2SeqAttackInputData(logits_test=valid_logits_test).validate) + self.assertRaises(ValueError, + Seq2SeqAttackInputData(labels_test=valid_labels_test).validate) + self.assertRaises(ValueError, + Seq2SeqAttackInputData(vocab_size=0).validate) + self.assertRaises(ValueError, + Seq2SeqAttackInputData(train_size=0).validate) + self.assertRaises(ValueError, + Seq2SeqAttackInputData(test_size=0).validate) + self.assertRaises(ValueError, Seq2SeqAttackInputData().validate) + + # Tests that both logits and labels must be set. + self.assertRaises(ValueError, + Seq2SeqAttackInputData( + logits_train=valid_logits_train, + logits_test=valid_logits_test, + vocab_size=0, + train_size=0, + test_size=0).validate) + self.assertRaises(ValueError, + Seq2SeqAttackInputData( + labels_train=valid_labels_train, + labels_test=valid_labels_test, + vocab_size=0, + train_size=0, + test_size=0).validate) + + # Tests that vocab, train, test sizes must all be set. + self.assertRaises(ValueError, + Seq2SeqAttackInputData( + logits_train=valid_logits_train, + logits_test=valid_logits_test, + labels_train=valid_labels_train, + labels_test=valid_labels_test).validate) + + self.assertRaises(ValueError, + Seq2SeqAttackInputData( + logits_train=invalid_logits_train, + logits_test=invalid_logits_test, + labels_train=invalid_labels_train, + labels_test=invalid_labels_test, + vocab_size=0, + train_size=0, + test_size=0 + ).validate) + + class RocCurveTest(absltest.TestCase): def test_auc_random_classifier(self): @@ -265,7 +330,6 @@ class AttackResultsCollectionTest(absltest.TestCase): class AttackResultsTest(absltest.TestCase): - perfect_classifier_result: SingleAttackResult random_classifier_result: SingleAttackResult