gazprea-fuzzer-python/ast_generator/test/test_ast_generator.py
ayrton d211131c4e IT'S ALIIIIIIIIVE
Took 1 hour 15 minutes
2023-11-20 20:28:55 -07:00

330 lines
11 KiB
Python

import unittest
import xml
import xml.etree.ElementTree as ET
import xml.dom.minidom
import yaml
from ast_generator.ast_generator import *
from ast_generator.gazprea_ast_grammar import *
def reachable_return(block):
    """Report whether *block* contains a reachable ``return`` statement.

    Placeholder: no control-flow analysis is performed yet, so every block
    is reported as having a reachable return.

    TODO: walk the block's statements and actually verify reachability.

    @param block: the ``<block>`` element to inspect (currently ignored).
    @return: always True.
    """
    del block  # unused until the real check is implemented
    return True
class TestGeneration(unittest.TestCase):
    """Unit tests for the individual ``generate_*`` routines of AstGenerator.

    One generator is built per test class from ``config.yaml`` (opened
    relative to the current working directory) and shared by every test.
    ``setUp`` resets its nesting counters; tests that need an empty tree call
    :meth:`_fresh_block` first.
    """

    @classmethod
    def setUpClass(cls):
        with open("config.yaml", 'r') as stream:
            props = yaml.safe_load(stream)
        cls.ast_gen = AstGenerator(props)

    def setUp(self):
        # The generator is shared across tests, so reset the nesting-depth
        # counters to keep state from a previous test from carrying over.
        self.ast_gen.current_nesting_depth = 0
        self.ast_gen.current_control_flow_nesting_depth = 0

    def _fresh_block(self):
        """Give the generator an empty root ``<block>`` and make it the insertion point.

        @return: the new root element (also stored as ``self.ast_gen.ast``).
        """
        root = ET.Element("block")
        self.ast_gen.ast = root
        self.ast_gen.current_ast_element = root
        return root

    @staticmethod
    def _find_first(element, *tags):
        """Return the first direct child of *element* matching any of *tags*, else None.

        Tags are tried in the given order. This replaces ``find(a) or find(b)``,
        which tests Element truthiness: an Element with no children is falsy
        (and truth-testing it is deprecated since Python 3.12), so a matching
        leaf element would be skipped.
        """
        for tag in tags:
            found = element.find(tag)
            if found is not None:
                return found
        return None

    def test_generate_literal(self):
        self._fresh_block()
        self.ast_gen.generate_literal('int')
        literal = self.ast_gen.ast.find(GAZ_LIT_TAG)
        self.assertIsNotNone(literal)
        self.assertEqual("int", literal.get("type"))
        self.assertIsNotNone(literal.get("value"))
        # The generated tree must serialize cleanly.
        self.assertIsNotNone(ET.tostring(self.ast_gen.ast, 'utf-8').decode('utf-8'))

    def test_generate_variable(self):
        out: Variable = self.ast_gen.generate_variable('int')
        self.assertEqual("int", out.xml.get("type"))
        self.assertIsNotNone(out.xml.get("name"))
        self.assertIsNotNone(out.xml.get("mut"))
        self.assertIsNotNone(ET.tostring(out.xml, 'utf-8').decode('utf-8'))

    def test_generate_declaration(self):
        self._fresh_block()
        self.ast_gen.generate_declaration()
        decl = self.ast_gen.ast.find("declaration")
        self.assertIsNotNone(decl)
        self.assertIsNotNone(decl.find("variable"))
        self.assertIsNotNone(decl.find("rhs"))

    def test_generate_assignment(self):
        self._fresh_block()
        # Declare first so there is a variable to assign to
        # (test_failing_assignment covers the empty-scope case).
        self.ast_gen.generate_declaration()
        self.ast_gen.generate_assignment()
        assignment = self.ast_gen.ast.find("assignment")
        self.assertIsNotNone(assignment)
        self.assertIsNotNone(assignment.find("variable"))
        self.assertIsNotNone(assignment.find("rhs"))

    def test_generate_bin_operation(self):
        self._fresh_block()
        self.ast_gen.generate_binary('+', 'int')
        operator = self.ast_gen.ast.find("operator")
        self.assertIsNotNone(operator)
        self.assertEqual('+', operator.get("op"))
        self.assertEqual('int', operator.get("type"))

    def test_generate_unary_operation(self):
        self._fresh_block()
        self.ast_gen.generate_unary('-', 'int')
        operator = self.ast_gen.ast.find("unary")
        self.assertIsNotNone(operator)
        self.assertEqual('-', operator.get("op"))
        self.assertEqual('int', operator.get("type"))

    def test_generate_stream(self):
        # Only generate_in_stream is exercised, so this runs two independent
        # input-stream samples.
        # TODO: add an output-stream ("std_output") case once the generator's
        # API for it is confirmed.
        for _ in range(2):
            self._fresh_block()
            self.ast_gen.generate_in_stream()
            in_stream = self.ast_gen.ast.find("stream")
            self.assertIsNotNone(in_stream)
            self.assertEqual("std_input", in_stream.get("type"))
            # NOTE(review): iter() always yields in_stream itself and .attrib is
            # always a dict, so this cannot fail; it only checks that the
            # subtree can be walked. Tighten once the expected children are known.
            last_attrib = None
            for node in in_stream.iter():
                last_attrib = node.attrib
            self.assertIsNotNone(last_attrib)

    def test_generate_block(self):
        self._fresh_block()
        self.ast_gen.generate_block()
        self.assertIsNotNone(self.ast_gen.ast.find("block"))
        # NOTE(review): iter() always yields the root itself and .attrib is
        # always a dict, so this cannot fail; it only checks that the tree can
        # be walked. Tighten once the expected contents are known.
        last_attrib = None
        for node in self.ast_gen.ast.iter():
            last_attrib = node.attrib
        self.assertIsNotNone(last_attrib)

    def test_generate_conditional(self):
        self._fresh_block()
        self.ast_gen.generate_conditional()
        self.assertIsNotNone(self.ast_gen.current_ast_element.find("conditional"))
        conditional = self.ast_gen.ast.find("conditional")
        self.assertIsNotNone(conditional)
        # The condition may be a binary operator, a unary operator or a literal.
        # "unary" is the tag generate_unary emits; "unary_operator" is also accepted.
        condition = self._find_first(
            conditional, "operator", "unary", "unary_operator", "literal")
        self.assertIsNotNone(condition)
        # One block for the "then" branch and one for the "else" branch.
        self.assertEqual(2, len(conditional.findall("block")))

    def test_generate_loop(self):
        self._fresh_block()
        self.ast_gen.generate_loop()
        loop = self.ast_gen.ast.find("loop")
        self.assertIsNotNone(loop)
        condition = self._find_first(
            loop, "operator", "unary", "unary_operator", "literal")
        self.assertIsNotNone(condition)
        self.assertEqual(1, len(loop.findall("block")))

    def test_generate_routine(self):
        self._fresh_block()
        self.ast_gen.generate_routine()
        routine = self._find_first(self.ast_gen.ast, "procedure", "function")
        self.assertIsNotNone(routine)
        self.assertIsNotNone(routine.find("block"))
        self.assertIsNotNone(routine.find("argument"))

    def test_generate_function_ASSERT_RETURNS(self):
        self._fresh_block()
        self.ast_gen.generate_routine(routine_type="function")
        routine = self.ast_gen.ast.find("function")
        self.assertIsNotNone(routine)
        self.assertIsNotNone(routine.find("block"))
        self.assertIsNotNone(routine.find("argument"))
        block = routine.find("block")
        # findall, not find: count the <return> statements directly in the
        # body. find() yields one element (or None), and len() of an element
        # is its child count rather than a number of returns.
        returns = block.findall("return")
        self.assertLess(0, len(returns))
        self.assertTrue(reachable_return(block))

    def test_generate_main(self):
        self._fresh_block()
        self.ast_gen.generate_main()
        main = self.ast_gen.ast.find("procedure")
        self.assertIsNotNone(main)
        self.assertEqual("main", main.get("name"))
        self.assertEqual("int", main.get("return_type"))
        block = main.find("block")
        self.assertIsNotNone(block)
        self.assertTrue(reachable_return(block))
        # main takes no arguments.
        self.assertIsNone(main.find("argument"))

    def test_generate_ast(self):
        self.ast_gen.generate_ast()
        self.assertIsNotNone(self.ast_gen.ast)
        procedures = self.ast_gen.ast.findall("procedure")
        self.assertLess(0, len(procedures))
        self.assertTrue(any(proc.get("name") == "main" for proc in procedures))

    def test_no_op_operation(self):
        # Sample random int/real expressions. Samples without a top-level
        # <operator> are skipped (not retried), so fewer than 1000 operators
        # may actually be checked.
        for _ in range(1000):
            self._fresh_block()
            self.ast_gen.generate_int_real_expr()
            operator = self.ast_gen.ast.find("operator")
            if operator is None:
                continue
            self.assertFalse(self.is_no_op(operator))

    def test_create_global(self):
        self._fresh_block()
        self.ast_gen.generate_main()
        global_block = self.ast_gen.current_ast_element
        global_scope = self.ast_gen.current_scope
        self.assertIsNotNone(self.ast_gen.ast.find("procedure"))
        # Generate the global from inside main; it must still land in the
        # outer (global) block and scope captured above.
        self.ast_gen.current_ast_element = self.ast_gen.ast.find("procedure")
        self.ast_gen.generate_global()
        self.assertGreater(len(global_scope.symbols), 0)
        self.assertIsNotNone(global_block.find("declaration"))

    def test_generate_assignment_no_declaration(self):
        for _ in range(1000):
            self._fresh_block()
            self.ast_gen.generate_declaration(mut='var')
            self.ast_gen.generate_assignment()
            assignment = self.ast_gen.ast.find("assignment")
            self.assertIsNotNone(assignment)
            # The assignment must not embed a declaration of its own.
            self.assertIsNone(assignment.find("declaration"))

    def test_failing_assignment(self):
        self._fresh_block()
        # Nothing has been declared in the fresh block, so assignment is
        # expected to fail with ValueError and leave no <assignment> behind.
        with self.assertRaises(ValueError):
            self.ast_gen.generate_assignment()
        self.assertIsNone(self.ast_gen.ast.find("assignment"))

    def is_no_op(self, operator):
        """
        Recursively check whether an expression contains a no-op operator.

        An operator is a no-op when its ``op`` attribute is the empty string.
        The root element is checked first, then every ``<operator>`` and
        ``<unary>`` element anywhere beneath it, covering both operands at
        any depth.

        @param operator: root ``<operator>`` (or ``<unary>``) element of an expression.
        @return: True if any checked element has ``op == ''``, otherwise False.
        """
        if operator.get("op") == '':
            return True
        return any(
            node.get("op") == ''
            for node in operator.iter()
            if node.tag in ("operator", "unary")
        )

    def write_ast(self):
        """Pretty-print the current AST to ``debug/ast/debug_<n>.xml`` for debugging.

        ``<n>`` is a random integer in [0, 1000] and is printed to stdout so the
        file can be found; a file with the same name is overwritten.
        """
        import os
        import random

        dom = xml.dom.minidom.parseString(ET.tostring(self.ast_gen.ast).decode('utf-8'))
        pretty: str = dom.toprettyxml()
        randint = random.randint(0, 1000)
        print(randint)
        os.makedirs("debug/ast", exist_ok=True)
        with open("debug/ast/debug_{}.xml".format(randint), 'w', encoding='utf-8') as f:
            f.write(pretty)
if __name__ == '__main__':
    import os

    # Dump 20 randomly generated programs as pretty-printed XML for inspection.
    with open("config.yaml", 'r') as stream:
        props = yaml.safe_load(stream)
    ast_gen = AstGenerator(props)
    os.makedirs("xml", exist_ok=True)
    for index in range(20):
        ast_gen.generate_ast()
        tree = ast_gen.ast
        # Mode 'x' refuses to overwrite: files from a previous run must be
        # removed first, otherwise FileExistsError is raised.
        with open(f"xml/ast{index}.xml", 'x', encoding='utf-8') as out_file:
            dom = xml.dom.minidom.parseString(ET.tostring(tree).decode('utf-8'))
            pretty: str = dom.toprettyxml()
            # Drop the "<?xml ...?>" declaration line that toprettyxml() emits;
            # every remaining line is written followed by a newline.
            out_file.write("".join(
                line + '\n'
                for line in pretty.split('\n')
                if not line.startswith("<?xml")
            ))