Add bracket expressions to the AST generator and the unparsers.

Review notes:
* This patch was re-flowed to one diff line per text line. Context-line
  indentation is reconstructed; if a hunk is rejected for whitespace, apply
  with `git apply --ignore-whitespace`.
* get_op: `op == 'addition' or 'noop'` is always true (a non-empty string is
  truthy), so every operator was rendered as '+'. It is now a membership test.
* get_numberlines: excluded_values was compared against the subsetting key and
  therefore never excluded anything. Exclusions are now applied to the option
  names of the matching subsetting, so 'negation' is left off the int number
  line as requested by the call in __init__. The unreachable int branch of the
  second loop is gone and the two loops are merged.
* __init__: dropped a no-op `pass`.
* The post-image blob id on the ast_generator.py index line predates these
  review changes.

diff --git a/ast_generator/ast_generator.py b/ast_generator/ast_generator.py
index e39fe18..c2ba44e 100644
--- a/ast_generator/ast_generator.py
+++ b/ast_generator/ast_generator.py
@@ -51,6 +51,57 @@ class AstGenerator:
         self.current_nesting_depth = 0
         self.current_control_flow_nesting_depth = 0
 
+        # Numberlines - For computing probabilities
+        self.int_op_options, self.int_op_cutoffs, self.int_op_numline = (
+            self.get_numberlines('expression-weights',
+                                 ['brackets', 'arithmetic', 'unary'],
+                                 [[], [], ['negation']]))
+        self.int_unary = ['negation', 'noop']
+
+    def get_numberlines(self, settings_section: str, subsettings: list[str],
+                        excluded_values: list[list[str]]):
+        """
+        @brief builds the number line used to pick a weighted random option
+
+        @param settings_section: key of the section of self.settings to read
+        @param subsettings: keys of that section that contribute options
+        @param excluded_values: per subsetting (same index), the option names
+                                to leave off the number line
+        @return (options, cutoffs, number_line): options maps an index to an
+                option name, cutoffs holds the running weight total after each
+                option, number_line is the sum of all weights
+        """
+        assert len(subsettings) == len(excluded_values)
+
+        # Pair each selected subsetting with the option names it must skip
+        excluded = dict(zip(subsettings, excluded_values))
+
+        options = {}
+        cutoffs = []
+        number_line = 0
+
+        for key, value in self.settings[settings_section].items():
+            if key not in excluded:
+                continue
+
+            if isinstance(value, int):
+                # A bare weight is a single option named after its own key
+                weights = {key: value}
+            elif isinstance(value, dict):
+                weights = value
+            else:
+                raise TypeError("invalid setting type. Found " + str(value) + " instead of expected int or dict")
+
+            skipped = excluded[key] or ()
+            for name, weight in weights.items():
+                if name in skipped:
+                    continue
+                number_line += weight
+                cutoffs.append(number_line)
+                options[len(options)] = name
+
+        return options, cutoffs, number_line
+
     def generate_ast(self):
         """
         @brief generates an AST from a grammar
@@ -214,7 +265,8 @@ class AstGenerator:
 
         unary = ["negation", "noop"]
 
-        self._generate_expression([GAZ_INT_KEY, GAZ_FLOAT_KEY], number_line, cutoffs, options, unary)
+        self._generate_expression([GAZ_INT_KEY, GAZ_FLOAT_KEY], self.int_op_numline, self.int_op_cutoffs,
+                                  self.int_op_options, self.int_unary)
 
     def generate_bool_expr(self):
         # Number line
@@ -259,6 +311,8 @@ class AstGenerator:
 
         if op in unary:
             self.generate_unary(op, random.choice(expr_type))
+        elif op == GAZ_BRACKET_TAG:
+            self.generate_bracket(random.choice(expr_type))
         else:
             self.generate_binary(op, random.choice(expr_type))
 
@@ -299,6 +353,19 @@ class AstGenerator:
 
         self.current_ast_element = parent
 
+    def generate_bracket(self, op_type):
+        parent = self.current_ast_element
+        args = [
+            ("type", op_type),
+        ]
+        element = build_xml_element(args, name=GAZ_BRACKET_TAG)
+        self.current_ast_element.append(element)
+        self.current_ast_element = element
+
+        self.generate_xhs(GAZ_RHS_TAG, op_type)
+
+        self.current_ast_element = parent
+
     def generate_xhs(self, handedness, op_type):
         element = build_xml_element([], name=handedness)
         parent = self.current_ast_element
@@ -667,7 +734,7 @@ def build_xml_element(*keys, name):
 
 
 def get_op(op):
-    if op == 'addition':
+    if op in ('addition', 'noop'):
         return '+'
     elif op == 'subtraction':
         return '-'
@@ -699,8 +766,6 @@ def get_op(op):
         return '-'
     elif op == 'not':
         return 'not'
-    elif op == 'noop':
-        return '+'
     elif op == 'concatenation':
         return '||'
     else:
diff --git a/ast_parser/general_unparser.py b/ast_parser/general_unparser.py
index fdfa39f..751d11d 100644
--- a/ast_parser/general_unparser.py
+++ b/ast_parser/general_unparser.py
@@ -72,7 +72,7 @@ class GeneralUnparser:
             self.unparse_node(node)
 
     def unparse_node(self, node):
-        if node.tag not in [GAZ_VAR_TAG, GAZ_RHS_TAG, GAZ_LHS_TAG, GAZ_LIT_TAG, GAZ_OPERATOR_TAG]:
+        if node.tag not in [GAZ_VAR_TAG, GAZ_RHS_TAG, GAZ_LHS_TAG, GAZ_LIT_TAG, GAZ_OPERATOR_TAG, GAZ_BRACKET_TAG]:
             self.source += self.indentation_character * self.indentation
 
         if node.tag == GAZ_BLOCK_TAG:
@@ -103,15 +103,22 @@ class GeneralUnparser:
             self.unparse_conditional(node)
         elif node.tag == GAZ_LOOP_TAG:
             self.unparse_loop(node)
+        elif node.tag == GAZ_BRACKET_TAG:
+            self.unparse_brackets(node)
         else:
             raise Exception("Unknown tag: " + node.tag)
 
     def unparse_block(self, node):
-        self.source += f"{self.indentation * self.indentation_character}{self.block_delimiters[0]}\n"
+        if node.get(GAZ_TY_KEY) is None:
+            self.source += f"{self.indentation * self.indentation_character}"
+
+        self.source += f"{self.block_delimiters[0]}\n"
         self.indentation += 4
+
         for child in node:
             self.unparse_node(child)
         self.indentation -= 4
+
         if node.get(GAZ_TY_KEY) is None:
             self.source += f"{self.indentation * self.indentation_character}{self.block_delimiters[1]}\n\n"
         elif node.get(GAZ_TY_KEY) in [GAZ_TRUE_BLOCK_TAG, GAZ_FALSE_BLOCK_TAG]:
@@ -271,6 +278,11 @@ class GeneralUnparser:
         self.source += "{}".format(self.translate_op(element_in.get("op")))
         self.unparse_xhs(element_in.find(GAZ_RHS_TAG))
 
+    def unparse_brackets(self, element_in: ET.Element):
+        self.source += "("
+        self.unparse_xhs(element_in.find(GAZ_RHS_TAG))
+        self.source += ")"
+
     def unparse_single_arg(self, param):
         return self.format_single_arg(self.translate_type(param.get(GAZ_TY_KEY)), param.get(GAZ_NAME_KEY))
 
diff --git a/ast_parser/python_unparser.py b/ast_parser/python_unparser.py
index 40401ee..69756c6 100644
--- a/ast_parser/python_unparser.py
+++ b/ast_parser/python_unparser.py
@@ -24,13 +24,13 @@ def to_python_type(ty):
 def to_python_op(param):
     if param == "negation" or param == "subtraction":
         return "-"
-    elif param == "addition":
+    elif param == "addition" or param == "noop":
         return "+"
     elif param == "multiplication":
         return "*"
     elif param == "division":
         return "/"
-    elif param == "modulus":
+    elif param == "modulo":
         return "%"
     elif param == "power":
         return "**"
diff --git a/constants.py b/constants.py
index 0937ba3..55fa513 100644
--- a/constants.py
+++ b/constants.py
@@ -45,3 +45,4 @@ GAZ_FALSE_BLOCK_TAG = "false"
 GAZ_ARG_TAG = "argument"
 GAZ_STRING_KEY = "string"
 GAZ_CHAR_KEY = "char"
+GAZ_BRACKET_TAG = "brackets"