diff --git a/ast_generator/ast_generator.py b/ast_generator/ast_generator.py index 7a63e2f..cf650ed 100644 --- a/ast_generator/ast_generator.py +++ b/ast_generator/ast_generator.py @@ -151,53 +151,70 @@ class AstGenerator: self.current_ast_element = parent def generate_return(self, return_type=None, return_value=None): + """ + @brief generates a return statement + + @param return_type: the type to be returned (if None -> any) + @param return_value: value to be returned (if None -> expr[return_type]) + """ if return_type is None or return_type == GAZ_VOID_TYPE: self.current_ast_element.append(self.make_element(GAZ_RETURN_TAG, [])) - return else: - keys = [("type", return_type)] - element = self.make_element(GAZ_RETURN_TAG, keys) - self.current_ast_element.append(element) + # store the parent parent = self.current_ast_element + + # initialize element + keys = [("type", return_type)] + self.make_element(GAZ_RETURN_TAG, keys) + + # make either a literal or a random expression based on choice if return_value is None: self.generate_expression(return_type) else: self.current_ast_element.append(self.make_literal(return_type, return_value)) + # return to the parent self.current_ast_element = parent def generate_routine(self, routine_type=None): + """ + @brief generate a random routine + + @param return_type: the type to be returned (if None -> any (including void)) + """ if routine_type is None: - routine_type = self.get_routine_type() + routine_type = self.get_routine_type() # get a random type else: - routine_type = routine_type + pass + # initialize random variables args = self.generate_routine_args() - name = self.get_name(routine_type) return_type = self.get_type(routine_type) + # initialize the routine routine = Routine(name, routine_type, return_type, args) - routine_args = [ ("name", routine.name), ("return_type", routine.return_type), ] - element = build_xml_element(routine_args, name=routine.type) - self.current_ast_element.append(element) + # Generation parent = self.current_ast_element - self.current_ast_element = element - self.push_scope() + self.make_scoped_element(routine.type, routine_args) self.define_args(routine.arguments) self.generate_block(return_stmt=True, return_type=routine.return_type) - self.pop_scope() - self.current_ast_element = parent + self.exit_scoped_element(parent) def define_args(self, args): + """ + @brief Generate the argument tags in a routine + + @param args: a list of arguments + """ for arg in args: self.current_ast_element.append(arg.xml) self.current_scope.append(arg.name, arg) diff --git a/ast_generator/test/test_ast_generator.py b/ast_generator/test/test_ast_generator.py index b1293b0..19d3d22 100644 --- a/ast_generator/test/test_ast_generator.py +++ b/ast_generator/test/test_ast_generator.py @@ -60,7 +60,7 @@ class TestGeneration(unittest.TestCase): def test_generate_assignment(self): self.ast_gen.ast = ET.Element("block") self.ast_gen.current_ast_element = self.ast_gen.ast - self.ast_gen.generate_declaration() + self.ast_gen.generate_declaration(mut='var') self.ast_gen.generate_assignment() self.assertIsNotNone(self.ast_gen.ast.find("assignment")) @@ -132,17 +132,19 @@ class TestGeneration(unittest.TestCase): print(ET.tostring(conditional, 'utf-8').decode('utf-8')) + self.has_child(conditional) + + block = conditional.findall("block") + self.assertEqual(2, len(block)) + + def has_child(self, conditional): opts = ['operator', 'unary_operator', 'literal', 'brackets'] res = [] for i in opts: res.append(conditional.find(i)) res_list = list(filter(lambda x: x is not None, res)) - self.assertGreater(len(res_list), 0) - block = conditional.findall("block") - self.assertEqual(2, len(block)) - def test_generate_loop(self): self.ast_gen.ast = ET.Element("block") self.ast_gen.current_ast_element = self.ast_gen.ast @@ -153,7 +155,7 @@ class TestGeneration(unittest.TestCase): # print(ET.tostring(loop, 'utf-8').decode('utf-8')) - self.assertIsNotNone(loop.find("operator") or loop.find("unary_operator") or loop.find("literal")) + self.has_child(loop) block = loop.findall("block") self.assertEqual(1, len(block))