Refactored ASTGenerator #3

Merged
aCompetentBean merged 8 commits from ayrton/refactor into main 2023-11-24 07:34:47 -07:00
2 changed files with 39 additions and 20 deletions
Showing only changes of commit eba774dd05 - Show all commits

View file

@ -151,53 +151,70 @@ class AstGenerator:
self.current_ast_element = parent self.current_ast_element = parent
def generate_return(self, return_type=None, return_value=None): def generate_return(self, return_type=None, return_value=None):
"""
@brief generates a return statement
@param return_type: the type to be returned (if None -> any)
@param return_value: value to be returned (if None -> expr[return_type])
"""
if return_type is None or return_type == GAZ_VOID_TYPE: if return_type is None or return_type == GAZ_VOID_TYPE:
self.current_ast_element.append(self.make_element(GAZ_RETURN_TAG, [])) self.current_ast_element.append(self.make_element(GAZ_RETURN_TAG, []))
return
else: else:
keys = [("type", return_type)] # store the parent
element = self.make_element(GAZ_RETURN_TAG, keys)
self.current_ast_element.append(element)
parent = self.current_ast_element parent = self.current_ast_element
# initialize element
keys = [("type", return_type)]
self.make_element(GAZ_RETURN_TAG, keys)
# make either a literal or a random expression based on choice
if return_value is None: if return_value is None:
self.generate_expression(return_type) self.generate_expression(return_type)
else: else:
self.current_ast_element.append(self.make_literal(return_type, return_value)) self.current_ast_element.append(self.make_literal(return_type, return_value))
# return to the parent
self.current_ast_element = parent self.current_ast_element = parent
def generate_routine(self, routine_type=None): def generate_routine(self, routine_type=None):
"""
@brief generate a random routine
@param return_type: the type to be returned (if None -> any (including void))
"""
if routine_type is None: if routine_type is None:
routine_type = self.get_routine_type() routine_type = self.get_routine_type() # get a random type
else: else:
routine_type = routine_type pass
# initialize random variables
args = self.generate_routine_args() args = self.generate_routine_args()
name = self.get_name(routine_type) name = self.get_name(routine_type)
return_type = self.get_type(routine_type) return_type = self.get_type(routine_type)
# initialize the routine
routine = Routine(name, routine_type, return_type, args) routine = Routine(name, routine_type, return_type, args)
routine_args = [ routine_args = [
("name", routine.name), ("name", routine.name),
("return_type", routine.return_type), ("return_type", routine.return_type),
] ]
element = build_xml_element(routine_args, name=routine.type) # Generation
self.current_ast_element.append(element)
parent = self.current_ast_element parent = self.current_ast_element
self.current_ast_element = element self.make_scoped_element(routine.type, routine_args)
self.push_scope()
self.define_args(routine.arguments) self.define_args(routine.arguments)
self.generate_block(return_stmt=True, return_type=routine.return_type) self.generate_block(return_stmt=True, return_type=routine.return_type)
self.pop_scope()
self.current_ast_element = parent self.exit_scoped_element(parent)
def define_args(self, args): def define_args(self, args):
"""
@brief Generate the argument tags in a routine
@param args: a list of arguments
"""
for arg in args: for arg in args:
self.current_ast_element.append(arg.xml) self.current_ast_element.append(arg.xml)
self.current_scope.append(arg.name, arg) self.current_scope.append(arg.name, arg)

View file

@ -60,7 +60,7 @@ class TestGeneration(unittest.TestCase):
def test_generate_assignment(self): def test_generate_assignment(self):
self.ast_gen.ast = ET.Element("block") self.ast_gen.ast = ET.Element("block")
self.ast_gen.current_ast_element = self.ast_gen.ast self.ast_gen.current_ast_element = self.ast_gen.ast
self.ast_gen.generate_declaration() self.ast_gen.generate_declaration(mut='var')
self.ast_gen.generate_assignment() self.ast_gen.generate_assignment()
self.assertIsNotNone(self.ast_gen.ast.find("assignment")) self.assertIsNotNone(self.ast_gen.ast.find("assignment"))
@ -132,17 +132,19 @@ class TestGeneration(unittest.TestCase):
print(ET.tostring(conditional, 'utf-8').decode('utf-8')) print(ET.tostring(conditional, 'utf-8').decode('utf-8'))
self.has_child(conditional)
block = conditional.findall("block")
self.assertEqual(2, len(block))
def has_child(self, conditional):
opts = ['operator', 'unary_operator', 'literal', 'brackets'] opts = ['operator', 'unary_operator', 'literal', 'brackets']
res = [] res = []
for i in opts: for i in opts:
res.append(conditional.find(i)) res.append(conditional.find(i))
res_list = list(filter(lambda x: x is not None, res)) res_list = list(filter(lambda x: x is not None, res))
self.assertGreater(len(res_list), 0) self.assertGreater(len(res_list), 0)
block = conditional.findall("block")
self.assertEqual(2, len(block))
def test_generate_loop(self): def test_generate_loop(self):
self.ast_gen.ast = ET.Element("block") self.ast_gen.ast = ET.Element("block")
self.ast_gen.current_ast_element = self.ast_gen.ast self.ast_gen.current_ast_element = self.ast_gen.ast
@ -153,7 +155,7 @@ class TestGeneration(unittest.TestCase):
# print(ET.tostring(loop, 'utf-8').decode('utf-8')) # print(ET.tostring(loop, 'utf-8').decode('utf-8'))
self.assertIsNotNone(loop.find("operator") or loop.find("unary_operator") or loop.find("literal")) self.has_child(loop)
block = loop.findall("block") block = loop.findall("block")
self.assertEqual(1, len(block)) self.assertEqual(1, len(block))