diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py index 25a9c98..23f9ba0 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py @@ -44,6 +44,22 @@ class AttackInputDataTest(absltest.TestCase): np.testing.assert_equal( attack_input.get_loss_test().tolist(), [1.0, 4.0, 6.0]) + def test_validator(self): + self.assertRaises(ValueError, + AttackInputData(logits_train=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData(labels_train=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData(loss_train=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData(logits_test=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData(labels_test=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData(loss_test=np.array([])).validate) + self.assertRaises(ValueError, + AttackInputData().validate) + if __name__ == '__main__': absltest.main()