331 lines
11 KiB
Python
331 lines
11 KiB
Python
import unittest
|
|
import xml
|
|
import xml.etree.ElementTree as ET
|
|
import xml.dom.minidom
|
|
|
|
import yaml
|
|
|
|
from ast_generator.ast_generator import *
|
|
from ast_generator.gazprea_ast_grammar import *
|
|
|
|
|
|
def reachable_return(block):
    """
    Report whether the given block contains a reachable return statement.

    Placeholder: no control-flow analysis is performed yet, so the answer is
    unconditionally True and assertions built on it cannot currently fail.

    @param block: the ``block`` element to analyse (currently unused)
    @return: always True for now
    """
    # TODO: walk the control flow of `block` and verify every path ends in a return.
    return True
|
|
|
|
|
|
class TestGeneration(unittest.TestCase):
|
|
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
with open("config.yaml", 'r') as stream:
|
|
props = yaml.safe_load(stream)
|
|
cls.ast_gen = AstGenerator(props)
|
|
|
|
def setUp(self):
|
|
self.ast_gen.current_nesting_depth = 0
|
|
self.ast_gen.current_control_flow_nesting_depth = 0
|
|
|
|
def test_generate_literal(self):
|
|
self.ast_gen.ast = ET.Element("block")
|
|
self.ast_gen.current_ast_element = self.ast_gen.ast
|
|
self.ast_gen.generate_literal('int')
|
|
|
|
self.assertIsNotNone(self.ast_gen.ast.find(GAZ_LIT_TAG))
|
|
self.assertEqual("int", self.ast_gen.ast.find(GAZ_LIT_TAG).get("type"))
|
|
self.assertIsNotNone(self.ast_gen.ast.find(GAZ_LIT_TAG).get("value"))
|
|
|
|
self.assertIsNotNone(ET.tostring(self.ast_gen.ast, 'utf-8').decode('utf-8'))
|
|
|
|
def test_generate_variable(self):
|
|
out: Variable = self.ast_gen.generate_variable('int')
|
|
|
|
self.assertEqual("int", out.xml.get("type"))
|
|
self.assertIsNotNone(out.xml.get("name"))
|
|
self.assertIsNotNone(out.xml.get("mut"))
|
|
|
|
self.assertIsNotNone(ET.tostring(out.xml, 'utf-8').decode('utf-8'))
|
|
|
|
def test_generate_declaration(self):
|
|
self.ast_gen.ast = ET.Element("block")
|
|
self.ast_gen.current_ast_element = self.ast_gen.ast
|
|
self.ast_gen.generate_declaration()
|
|
|
|
self.assertIsNotNone(self.ast_gen.ast.find("declaration"))
|
|
|
|
decl = self.ast_gen.ast.find("declaration")
|
|
self.assertIsNotNone(decl.find("variable"))
|
|
self.assertIsNotNone(decl.find("rhs"))
|
|
|
|
# print(ET.tostring(decl, 'utf-8').decode('utf-8'))
|
|
|
|
def test_generate_assignment(self):
|
|
self.ast_gen.ast = ET.Element("block")
|
|
self.ast_gen.current_ast_element = self.ast_gen.ast
|
|
self.ast_gen.generate_declaration()
|
|
self.ast_gen.generate_assignment()
|
|
|
|
self.assertIsNotNone(self.ast_gen.ast.find("assignment"))
|
|
|
|
decl = self.ast_gen.ast.find("assignment")
|
|
|
|
# print(ET.tostring(decl, 'utf-8').decode('utf-8'))
|
|
|
|
self.assertIsNotNone(decl.find("variable"))
|
|
self.assertIsNotNone(decl.find("rhs"))
|
|
|
|
def test_generate_bin_operation(self):
|
|
self.ast_gen.ast = ET.Element("block")
|
|
self.ast_gen.current_ast_element = self.ast_gen.ast
|
|
self.ast_gen.generate_binary('+', 'int')
|
|
|
|
self.assertIsNotNone(self.ast_gen.ast.find("operator"))
|
|
operator = self.ast_gen.ast.find("operator")
|
|
self.assertEqual('+', operator.get("op"))
|
|
self.assertEqual('int', operator.get("type"))
|
|
|
|
def test_generate_unary_operation(self):
|
|
self.ast_gen.ast = ET.Element("block")
|
|
self.ast_gen.current_ast_element = self.ast_gen.ast
|
|
self.ast_gen.generate_unary('-', 'int')
|
|
|
|
self.assertIsNotNone(self.ast_gen.ast.find("unary"))
|
|
operator = self.ast_gen.ast.find("unary")
|
|
self.assertEqual('-', operator.get("op"))
|
|
self.assertEqual('int', operator.get("type"))
|
|
|
|
def test_generate_stream(self):
|
|
for type in ["std_input", "std_output"]:
|
|
self.ast_gen.ast = ET.Element("block")
|
|
self.ast_gen.current_ast_element = self.ast_gen.ast
|
|
self.ast_gen.generate_in_stream()
|
|
|
|
self.assertIsNotNone(self.ast_gen.ast.find("stream"))
|
|
in_stream = self.ast_gen.ast.find("stream")
|
|
self.assertEqual("std_input", in_stream.get("type"))
|
|
|
|
lad = None
|
|
for child in in_stream.iter():
|
|
lad = child.attrib
|
|
|
|
self.assertIsNotNone(lad)
|
|
|
|
def test_generate_block(self):
|
|
self.ast_gen.ast = ET.Element("block")
|
|
self.ast_gen.current_ast_element = self.ast_gen.ast
|
|
self.ast_gen.generate_block()
|
|
|
|
self.assertIsNotNone(self.ast_gen.ast.find("block"))
|
|
|
|
elem = None
|
|
for child in self.ast_gen.ast.iter():
|
|
elem = child.attrib
|
|
self.assertIsNotNone(elem)
|
|
|
|
def test_generate_conditional(self):
|
|
self.ast_gen.ast = ET.Element("block")
|
|
self.ast_gen.current_ast_element = self.ast_gen.ast
|
|
self.ast_gen.generate_conditional()
|
|
|
|
# print(ET.tostring(self.ast_gen.ast, 'utf-8').decode('utf-8'))
|
|
|
|
self.assertIsNotNone(self.ast_gen.current_ast_element.find("conditional"))
|
|
conditional = self.ast_gen.ast.find("conditional")
|
|
|
|
# print(ET.tostring(conditional, 'utf-8').decode('utf-8'))
|
|
|
|
self.assertIsNotNone(conditional.find("operator") or conditional.find("unary_operator") or conditional.find("literal"))
|
|
|
|
block = conditional.findall("block")
|
|
self.assertEqual(2, len(block))
|
|
|
|
def test_generate_loop(self):
|
|
self.ast_gen.ast = ET.Element("block")
|
|
self.ast_gen.current_ast_element = self.ast_gen.ast
|
|
self.ast_gen.generate_loop()
|
|
|
|
self.assertIsNotNone(self.ast_gen.ast.find("loop"))
|
|
loop = self.ast_gen.ast.find("loop")
|
|
|
|
# print(ET.tostring(loop, 'utf-8').decode('utf-8'))
|
|
|
|
self.assertIsNotNone(loop.find("operator") or loop.find("unary_operator") or loop.find("literal"))
|
|
|
|
block = loop.findall("block")
|
|
self.assertEqual(1, len(block))
|
|
|
|
|
|
def test_generate_routine(self):
|
|
self.ast_gen.ast = ET.Element("block")
|
|
self.ast_gen.current_ast_element = self.ast_gen.ast
|
|
self.ast_gen.generate_routine()
|
|
|
|
self.assertIsNotNone(self.ast_gen.ast.find("procedure") or self.ast_gen.ast.find("function"))
|
|
routine = self.ast_gen.ast.find("procedure") or self.ast_gen.ast.find("function")
|
|
|
|
# print(ET.tostring(routine, 'utf-8').decode('utf-8'))
|
|
|
|
self.assertIsNotNone(routine.find("block"))
|
|
self.assertIsNotNone(routine.find("argument"))
|
|
|
|
def test_generate_function_ASSERT_RETURNS(self):
|
|
self.ast_gen.ast = ET.Element("block")
|
|
self.ast_gen.current_ast_element = self.ast_gen.ast
|
|
self.ast_gen.generate_routine(routine_type="function")
|
|
|
|
self.assertIsNotNone(self.ast_gen.ast.find("function"))
|
|
routine = self.ast_gen.ast.find("function")
|
|
|
|
# print(ET.tostring(routine, 'utf-8').decode('utf-8'))
|
|
|
|
self.assertIsNotNone(routine.find("block"))
|
|
self.assertIsNotNone(routine.find("argument"))
|
|
|
|
block = routine.find("block")
|
|
# print(ET.tostring(block, 'utf-8').decode('utf-8'))
|
|
rets = block.find("return")
|
|
# print(rets)
|
|
self.assertLess(0, len(rets))
|
|
self.assertTrue(reachable_return(block))
|
|
|
|
|
|
def test_generate_main(self):
|
|
self.ast_gen.ast = ET.Element("block")
|
|
self.ast_gen.current_ast_element = self.ast_gen.ast
|
|
self.ast_gen.generate_main()
|
|
|
|
self.assertIsNotNone(self.ast_gen.ast.find("procedure"))
|
|
out = self.ast_gen.ast.find("procedure")
|
|
|
|
# print(ET.tostring(out, 'utf-8').decode('utf-8'))
|
|
|
|
self.assertEqual("main", out.get("name"))
|
|
self.assertEqual("int", out.get("return_type"))
|
|
|
|
self.assertIsNotNone(out.find("block"))
|
|
block = out.find("block")
|
|
self.assertTrue(reachable_return(block))
|
|
|
|
self.assertIsNone(out.find("argument"))
|
|
|
|
def test_generate_ast(self):
|
|
self.ast_gen.generate_ast()
|
|
|
|
self.assertIsNotNone(self.ast_gen.ast)
|
|
|
|
# print(ET.tostring(self.ast_gen.ast, 'utf-8').decode('utf-8'))
|
|
|
|
procedures = self.ast_gen.ast.findall("procedure")
|
|
self.assertLess(0, len(procedures))
|
|
|
|
main = False
|
|
for proc in procedures:
|
|
if proc.get("name") == "main":
|
|
main = True
|
|
self.assertTrue(main)
|
|
|
|
def test_no_op_operation(self):
|
|
for l in range(1000):
|
|
# print("iteration: " + str(l))
|
|
self.ast_gen.ast = ET.Element("block")
|
|
self.ast_gen.current_ast_element = self.ast_gen.ast
|
|
self.ast_gen.generate_int_expr()
|
|
# self.write_ast()
|
|
|
|
if self.ast_gen.ast.find("operator") is None:
|
|
l -= 1
|
|
continue
|
|
operator = self.ast_gen.ast.find("operator")
|
|
self.assertFalse(self.is_no_op(operator))
|
|
|
|
def test_create_global(self):
|
|
self.ast_gen.ast = ET.Element("block")
|
|
self.ast_gen.current_ast_element = self.ast_gen.ast
|
|
self.ast_gen.generate_main()
|
|
|
|
global_block = self.ast_gen.current_ast_element
|
|
global_scope = self.ast_gen.current_scope
|
|
|
|
self.assertIsNotNone(self.ast_gen.ast.find("procedure"))
|
|
self.ast_gen.current_ast_element = self.ast_gen.ast.find("procedure")
|
|
self.ast_gen.generate_global()
|
|
|
|
self.assertGreater(len(global_scope.symbols), 0)
|
|
self.assertIsNotNone(global_block.find("declaration"))
|
|
|
|
def test_generate_assignment_no_declaration(self):
|
|
for l in range(1000):
|
|
self.ast_gen.ast = ET.Element("block")
|
|
self.ast_gen.current_ast_element = self.ast_gen.ast
|
|
self.ast_gen.generate_declaration(mut='var')
|
|
self.ast_gen.generate_assignment()
|
|
|
|
self.assertIsNotNone(self.ast_gen.ast.find("assignment"))
|
|
|
|
decl = self.ast_gen.ast.find("assignment")
|
|
self.assertIsNone(decl.find("declaration"))
|
|
|
|
def test_failing_assignment(self):
|
|
self.ast_gen.ast = ET.Element("block")
|
|
self.ast_gen.current_ast_element = self.ast_gen.ast
|
|
with self.assertRaises(ValueError):
|
|
self.ast_gen.generate_assignment()
|
|
print(ET.tostring(self.ast_gen.ast, 'utf-8').decode('utf-8'))
|
|
|
|
self.assertIsNone(self.ast_gen.ast.find("assignment"))
|
|
|
|
def is_no_op(self, operator):
|
|
"""
|
|
recursively check if operator is a no-op
|
|
@param operator:
|
|
@return:
|
|
"""
|
|
res = False
|
|
if operator.get("op") == '':
|
|
return True
|
|
else:
|
|
lhs = operator.find("lhs")
|
|
rhs = operator.find("rhs")
|
|
if lhs.find("operator") is not None:
|
|
res = self.is_no_op(lhs.find("operator"))
|
|
elif lhs.find("unary") is not None:
|
|
res = self.is_no_op(lhs.find("unary"))
|
|
elif rhs.find("operator") is not None:
|
|
res = self.is_no_op(rhs.find("operator"))
|
|
elif rhs.find("unary") is not None:
|
|
res = self.is_no_op(lhs.find("unary"))
|
|
|
|
return res
|
|
|
|
def write_ast(self):
|
|
dom = xml.dom.minidom.parseString(ET.tostring(self.ast_gen.ast).decode('utf-8'))
|
|
pretty: str = dom.toprettyxml()
|
|
|
|
randint = random.randint(0, 1000)
|
|
print(randint)
|
|
with open("debug/ast/debug_{}.xml".format(randint), 'w') as f:
|
|
f.write(pretty)
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Generate 20 ASTs and save each one as pretty-printed XML, minus the
    # <?xml ...?> declaration line, to xml/ast<N>.xml. Mode 'x' refuses to
    # overwrite an existing file.
    with open("config.yaml", 'r') as config_file:
        props = yaml.safe_load(config_file)
    generator = AstGenerator(props)

    for index in range(20):
        generator.generate_ast()
        tree = generator.ast

        with open(f"xml/ast{index}.xml", 'x') as out_file:
            document = xml.dom.minidom.parseString(ET.tostring(tree).decode('utf-8'))
            pretty: str = document.toprettyxml()
            kept_lines = (
                line + '\n'
                for line in pretty.split('\n')
                if not line.startswith("<?xml")
            )
            out_file.write("".join(kept_lines))
|