gazprea-fuzzer-python/ast_generator/test/test_ast_generator.py
2023-11-18 10:59:00 -07:00

246 lines
8.3 KiB
Python

import unittest
import xml
import xml.etree.ElementTree as ET
import xml.dom.minidom
import yaml
from ast_generator.ast_generator import *
from ast_generator.gazprea_ast_grammar import *
def reachable_return(block):
    """Report whether ``block`` contains a reachable return statement.

    Placeholder: the argument is not inspected and the result is always
    ``True``, so assertions built on this helper currently cannot fail.
    """
    # TODO: actually walk ``block`` and verify that a return is reachable.
    return True
class TestGeneration(unittest.TestCase):
    """Structural smoke tests for ``AstGenerator``.

    Most tests point the shared generator at a fresh ``<block>`` root, call a
    single ``generate_*`` method, and then inspect the XML found under
    ``self.ast_gen.ast``.
    """

    @classmethod
    def setUpClass(cls):
        """Create one ``AstGenerator`` from ``config.yaml`` for the whole class."""
        # Relative path: the tests must be started from the directory that
        # contains config.yaml.
        with open("config.yaml", 'r') as stream:
            props = yaml.safe_load(stream)
        cls.ast_gen = AstGenerator(props)

    def setUp(self):
        """Reset the generator's nesting counters before every test."""
        self.ast_gen.current_nesting_depth = 0
        self.ast_gen.current_control_flow_nesting_depth = 0

    def _new_root(self):
        """Install an empty ``<block>`` as both the AST root and the current element."""
        self.ast_gen.ast = ET.Element("block")
        self.ast_gen.current_ast_element = self.ast_gen.ast

    @staticmethod
    def _first_child(parent, *tags):
        """Return the match for the earliest tag in ``tags`` found under ``parent``.

        Tags are tried in the order given; ``None`` is returned when none of
        them matches a direct child.

        ``Element.find`` results are compared against ``None`` explicitly
        because an element without sub-elements is falsy: chaining lookups
        with ``or`` would skip an element that is present but childless.
        """
        for tag in tags:
            found = parent.find(tag)
            if found is not None:
                return found
        return None

    def test_generate_literal(self):
        """A generated int literal carries ``type="int"`` and a ``value`` attribute."""
        self._new_root()
        self.ast_gen.generate_literal('int')

        literal = self.ast_gen.ast.find(GAZ_LIT_TAG)
        self.assertIsNotNone(literal)
        self.assertEqual("int", literal.get("type"))
        self.assertIsNotNone(literal.get("value"))
        # Serialising must not raise.
        self.assertIsNotNone(ET.tostring(self.ast_gen.ast, 'utf-8').decode('utf-8'))

    def test_generate_variable(self):
        """A generated variable exposes ``type``, ``name`` and ``mut`` on its XML."""
        out: Variable = self.ast_gen.generate_variable('int')

        self.assertEqual("int", out.xml.get("type"))
        self.assertIsNotNone(out.xml.get("name"))
        self.assertIsNotNone(out.xml.get("mut"))
        # Serialising must not raise.
        self.assertIsNotNone(ET.tostring(out.xml, 'utf-8').decode('utf-8'))

    def test_generate_declaration(self):
        """A declaration contains a ``variable`` and an ``rhs`` child."""
        self._new_root()
        self.ast_gen.generate_declaration()

        decl = self.ast_gen.ast.find("declaration")
        self.assertIsNotNone(decl)
        self.assertIsNotNone(decl.find("variable"))
        self.assertIsNotNone(decl.find("rhs"))

    def test_generate_assignment(self):
        """An assignment contains a ``variable`` and an ``rhs`` child."""
        self._new_root()
        # A declaration is generated first, as in the original test
        # (presumably so that a variable exists to assign to).
        self.ast_gen.generate_declaration()
        self.ast_gen.generate_assignment()

        assignment = self.ast_gen.ast.find("assignment")
        self.assertIsNotNone(assignment)
        self.assertIsNotNone(assignment.find("variable"))
        self.assertIsNotNone(assignment.find("rhs"))

    def test_generate_bin_operation(self):
        """A binary ``+`` on int yields an ``operator`` element with matching attributes."""
        self._new_root()
        self.ast_gen.generate_binary('+', 'int')

        operator = self.ast_gen.ast.find("operator")
        self.assertIsNotNone(operator)
        self.assertEqual('+', operator.get("op"))
        self.assertEqual('int', operator.get("type"))

    def test_generate_unary_operation(self):
        """A unary ``-`` on int yields a ``unary`` element with matching attributes."""
        self._new_root()
        self.ast_gen.generate_unary('-', 'int')

        operator = self.ast_gen.ast.find("unary")
        self.assertIsNotNone(operator)
        self.assertEqual('-', operator.get("op"))
        self.assertEqual('int', operator.get("type"))

    def test_generate_stream(self):
        """An input stream statement yields a ``stream`` element of type ``std_input``."""
        # NOTE(review): the loop value is never used, so "std_output" is not
        # exercised; both passes generate and check an input stream. Add the
        # output-stream case once the matching generator call is confirmed.
        for _stream_type in ["std_input", "std_output"]:
            self._new_root()
            self.ast_gen.generate_in_stream()

            in_stream = self.ast_gen.ast.find("stream")
            self.assertIsNotNone(in_stream)
            self.assertEqual("std_input", in_stream.get("type"))

            # iter() yields in_stream itself first, so ``lad`` is always set;
            # this only guards against the subtree being un-iterable.
            lad = None
            for child in in_stream.iter():
                lad = child.attrib
            self.assertIsNotNone(lad)

    def test_generate_block(self):
        """Generating a block adds a nested ``block`` element under the root."""
        self._new_root()
        self.ast_gen.generate_block()

        self.assertIsNotNone(self.ast_gen.ast.find("block"))
        # iter() yields the root itself first, so ``elem`` is always set.
        elem = None
        for child in self.ast_gen.ast.iter():
            elem = child.attrib
        self.assertIsNotNone(elem)

    def test_generate_conditional(self):
        """A conditional has a condition expression and exactly two blocks."""
        self._new_root()
        self.ast_gen.generate_conditional()
        print(ET.tostring(self.ast_gen.ast, 'utf-8').decode('utf-8'))

        # Looked up through current_ast_element on purpose: the conditional
        # must be a direct child of whatever element the generator is left on.
        self.assertIsNotNone(self.ast_gen.current_ast_element.find("conditional"))
        conditional = self.ast_gen.ast.find("conditional")

        # NOTE(review): test_generate_unary_operation looks for the tag
        # "unary", not "unary_operator" - confirm which tag is emitted here.
        condition = self._first_child(conditional, "operator", "unary_operator", "literal")
        self.assertIsNotNone(condition)
        self.assertEqual(2, len(conditional.findall("block")))

    def test_generate_loop(self):
        """A loop has a condition expression and exactly one block."""
        self._new_root()
        self.ast_gen.generate_loop()

        loop = self.ast_gen.ast.find("loop")
        self.assertIsNotNone(loop)

        # NOTE(review): see the tag-name question in test_generate_conditional
        # ("unary" vs "unary_operator").
        condition = self._first_child(loop, "operator", "unary_operator", "literal")
        self.assertIsNotNone(condition)
        self.assertEqual(1, len(loop.findall("block")))

    def test_generate_routine(self):
        """A routine is a ``procedure`` or ``function`` with a block and an argument."""
        self._new_root()
        self.ast_gen.generate_routine()

        routine = self._first_child(self.ast_gen.ast, "procedure", "function")
        self.assertIsNotNone(routine)
        print(ET.tostring(routine, 'utf-8').decode('utf-8'))
        self.assertIsNotNone(routine.find("block"))
        self.assertIsNotNone(routine.find("argument"))

    def test_generate_function_ASSERT_RETURNS(self):
        """A generated function's block directly contains at least one ``return``."""
        self._new_root()
        self.ast_gen.generate_routine(routine_type="function")

        routine = self.ast_gen.ast.find("function")
        self.assertIsNotNone(routine)
        print(ET.tostring(routine, 'utf-8').decode('utf-8'))
        self.assertIsNotNone(routine.find("block"))
        self.assertIsNotNone(routine.find("argument"))

        block = routine.find("block")
        print(ET.tostring(block, 'utf-8').decode('utf-8'))
        # findall, not find: we count the return statements themselves. With
        # find(), len() would count the children of the first <return> and
        # raise TypeError when no <return> exists.
        rets = block.findall("return")
        print(rets)
        self.assertLess(0, len(rets))
        self.assertTrue(reachable_return(block))

    def test_generate_main(self):
        """``main`` is an argument-less procedure returning int, with a block."""
        self._new_root()
        self.ast_gen.generate_main()

        out = self.ast_gen.ast.find("procedure")
        self.assertIsNotNone(out)
        print(ET.tostring(out, 'utf-8').decode('utf-8'))
        self.assertEqual("main", out.get("name"))
        self.assertEqual("int", out.get("return_type"))

        block = out.find("block")
        self.assertIsNotNone(block)
        self.assertTrue(reachable_return(block))
        self.assertIsNone(out.find("argument"))

    def test_generate_ast(self):
        """A full AST has at least one procedure, one of which is named ``main``."""
        self.ast_gen.generate_ast()
        self.assertIsNotNone(self.ast_gen.ast)
        print(ET.tostring(self.ast_gen.ast, 'utf-8').decode('utf-8'))

        procedures = self.ast_gen.ast.findall("procedure")
        self.assertLess(0, len(procedures))
        self.assertTrue(any(proc.get("name") == "main" for proc in procedures))
if __name__ == '__main__':
    # Script mode: generate 20 ASTs and dump each one, pretty-printed, to
    # xml/ast<N>.xml. Note that this does not run the unit tests above.
    import os

    with open("config.yaml", 'r') as stream:
        props = yaml.safe_load(stream)
    ast_gen = AstGenerator(props)

    # The output directory may not exist on a fresh checkout.
    os.makedirs("xml", exist_ok=True)

    for a in range(20):
        ast_gen.generate_ast()
        ast = ast_gen.ast

        # Build the full text before touching the file system, so a failure
        # here cannot leave behind an empty file that blocks the next run.
        dom = xml.dom.minidom.parseString(ET.tostring(ast).decode('utf-8'))
        pretty: str = dom.toprettyxml()
        # Drop the "<?xml ...?>" declaration line; every other line is kept
        # and terminated with a newline.
        repretty = "".join(
            line + '\n'
            for line in pretty.split('\n')
            if not line.startswith("<?xml")
        )

        # 'x' refuses to overwrite an existing dump (raises FileExistsError).
        # The declaration was stripped, so the file is written as UTF-8, the
        # encoding XML readers assume when no declaration is present.
        with open(f"xml/ast{a}.xml", 'x', encoding="utf-8") as t:
            t.write(repretty)