Refactored ASTGenerator #3

Merged
aCompetentBean merged 8 commits from ayrton/refactor into main 2023-11-24 07:34:47 -07:00
2 changed files with 105 additions and 62 deletions
Showing only changes of commit 5ea6eca0ba - Show all commits

View file

@ -241,113 +241,116 @@ class AstGenerator:
# Generate the statements
self._generate_from_options(cutoffs, number_line, options)
def generate_int_expr(self):
self._generate_expression([GAZ_INT_KEY],
self.int_op_numline,
self.int_op_cutoffs,
self.int_op_options,
self.int_unary)
def generate_float_expr(self):
self._generate_expression([GAZ_FLOAT_KEY, GAZ_INT_KEY],
self.float_op_numline,
self.float_op_cutoffs,
self.float_op_options,
self.float_unary)
def generate_bool_expr(self):
self._generate_expression([GAZ_BOOL_KEY],
self.bool_op_numline,
self.bool_op_cutoffs,
self.bool_op_options,
self.bool_unary)
def generate_char_expr(self):
self._generate_expression([GAZ_CHAR_KEY],
self.char_op_numline,
self.char_op_cutoffs,
self.char_op_options)
def generate_comp_expr(self):
self._generate_expression([GAZ_BOOL_KEY],
self.comp_op_numline,
self.comp_op_cutoffs,
self.comp_op_options,
comparison=True)
def _generate_expression(self, expr_type: list[str], number_line,
cutoffs, options, unary=None, comparison: bool = False):
"""
@brief Generate an expression
@param expr_type: a list of types to be used
@param number_line: number line for probability computation
@param cutoffs: cutoffs to be used
@param options: options to be used
@param unary: a list of unary operators in options
"""
if unary is None:
unary = []
parent = self.current_ast_element
self.current_nesting_depth += 1
# Check the expression depth against settings
if self.current_nesting_depth > self.settings['generation-options']['max-nesting-depth'] or random.random() < \
self.settings['block-termination-probability']:
self.generate_literal(random.choice(expr_type))
self.current_nesting_depth -= 1
return
# Generate
op = _choose_option(cutoffs, number_line, options)
self._generate_expr(comparison, expr_type, op, unary)
# Return to parent
self.current_nesting_depth -= 1
self.current_ast_element = parent
def generate_declaration(self, mut=None):
def generate_declaration(self, mut=None): # TODO change this to a bool
"""
@brief Generate a declaration
@param mut: the mutability of the variable ('const' or 'var')
"""
# Initialize the variable
parent = self.current_ast_element
decl_type = self.get_type(GAZ_VAR_TAG)
decl_args = [
("type", decl_type),
]
element = build_xml_element(decl_args, name=GAZ_DECLARATION_TAG)
self.current_ast_element.append(element)
self.current_ast_element = element
self.make_element(GAZ_DECLARATION_TAG, decl_args)
# Generate the variable
variable = self.generate_variable(decl_type, mut=mut)
self.current_ast_element.append(variable.xml)
self.current_scope.append(variable.name, variable)
self.current_scope.append(variable.name, variable) # make sure the variable is in scope
self.generate_xhs(GAZ_RHS_TAG, decl_type) # TODO add real type (decl_type)
# Generate the initialization of the variable
self.generate_xhs(GAZ_RHS_TAG, decl_type)
# Return to parent
self.current_ast_element = parent
def generate_binary(self, op, op_type):
"""
@brief Generate a binary operation
@param op: the operator
@param op_type: the type of the expression
"""
parent = self.current_ast_element
# Check if the operator is valid
if op == "":
raise ValueError("op is empty!")
args = [
("op", op),
("type", op_type),
]
element = build_xml_element(args, name=GAZ_OPERATOR_TAG)
self.current_ast_element.append(element)
self.current_ast_element = element
self.make_element(GAZ_OPERATOR_TAG, args)
# Gnereate lhs and rhs
self.generate_xhs(GAZ_LHS_TAG, op_type)
self.generate_xhs(GAZ_RHS_TAG, op_type)
# Return to parent
self.current_ast_element = parent
def generate_bracket(self, op_type):
parent = self.current_ast_element
args = [
("type", op_type),
]
element = build_xml_element(args, name=GAZ_BRACKET_TAG)
self.current_ast_element.append(element)
self.current_ast_element = element
"""
@brief Generate a bracket operation
@param op_type: the type of the expression
"""
parent = self.current_ast_element
args = [("type", op_type)]
self.make_element(GAZ_BRACKET_TAG, args)
# Generate the expression in the brackets
self.generate_xhs(GAZ_RHS_TAG, op_type)
# Return to parent
self.current_ast_element = parent
def generate_xhs(self, handedness, op_type, is_zero=False):
element = build_xml_element([], name=handedness)
"""
@brief generate a lhs or a rhs depending on handedness
@param handedness: the handedness
@param op_type: the type of the expression
@param is_zero: if the expression is zero
"""
parent = self.current_ast_element
self.current_ast_element.append(element)
self.current_ast_element = element
self.make_element(handedness, [])
self.generate_expression(op_type, is_zero=is_zero)
@ -541,6 +544,40 @@ class AstGenerator:
def generate_arg(self):
return Argument(self.get_name(GAZ_VAR_TAG), self.get_type(GAZ_VAR_TAG))
def generate_int_expr(self):
self._generate_expression([GAZ_INT_KEY],
self.int_op_numline,
self.int_op_cutoffs,
self.int_op_options,
self.int_unary)
def generate_float_expr(self):
self._generate_expression([GAZ_FLOAT_KEY, GAZ_INT_KEY],
self.float_op_numline,
self.float_op_cutoffs,
self.float_op_options,
self.float_unary)
def generate_bool_expr(self):
self._generate_expression([GAZ_BOOL_KEY],
self.bool_op_numline,
self.bool_op_cutoffs,
self.bool_op_options,
self.bool_unary)
def generate_char_expr(self):
self._generate_expression([GAZ_CHAR_KEY],
self.char_op_numline,
self.char_op_cutoffs,
self.char_op_options)
def generate_comp_expr(self):
self._generate_expression([GAZ_BOOL_KEY],
self.comp_op_numline,
self.comp_op_cutoffs,
self.comp_op_options,
comparison=True)
def push_scope(self, xml_element: ET.Element = None):
scope = Scope(self.current_scope)
self.symbol_table.append(scope)

View file

@ -292,6 +292,12 @@ class TestGeneration(unittest.TestCase):
else:
lhs = operator.find("lhs")
rhs = operator.find("rhs")
if lhs is None:
if rhs.find("operator") is not None:
res = self.is_no_op(rhs.find("operator"))
elif rhs.find("unary") is not None:
res = self.is_no_op(rhs.find("unary"))
else:
if lhs.find("operator") is not None:
res = self.is_no_op(lhs.find("operator"))
elif lhs.find("unary") is not None:
@ -299,7 +305,7 @@ class TestGeneration(unittest.TestCase):
elif rhs.find("operator") is not None:
res = self.is_no_op(rhs.find("operator"))
elif rhs.find("unary") is not None:
res = self.is_no_op(lhs.find("unary"))
res = self.is_no_op(rhs.find("unary"))
return res