diff --git a/ast_generator/ast_generator.py b/ast_generator/ast_generator.py index 19b8e7d..56d5ced 100644 --- a/ast_generator/ast_generator.py +++ b/ast_generator/ast_generator.py @@ -3,8 +3,10 @@ import warnings from english_words import get_english_words_set +from ast_generator.tiny_py_unparser import TinyPyUnparser from ast_generator.utils import * from ast_generator.utils import filter_options, _choose_option +from ast_parser.python_unparser import PythonUnparser from constants import * import keyword @@ -49,6 +51,8 @@ class AstGenerator: self.current_nesting_depth = 0 self.current_control_flow_nesting_depth = 0 + self.py_unparser = None + self._init_numlines() def _init_numlines(self): @@ -64,7 +68,7 @@ class AstGenerator: settings=self.settings)) self.bool_unary = ['not'] self.float_op_options, self.float_op_cutoffs, self.float_op_numline = ( - get_numberlines('expression-weights', ['brackets', 'arithmetic', 'unary'], [[], [], ['not']], + get_numberlines('expression-weights', ['brackets', 'arithmetic', 'unary'], [[], ['modulo'], ['not']], self.settings)) self.float_unary = ['negation', 'noop'] self.char_op_options, self.char_op_cutoffs, self.char_op_numline = ( @@ -254,7 +258,8 @@ class AstGenerator: self._generate_from_options(cutoffs, number_line, options) def _generate_expression(self, expr_type: list[str], number_line, - cutoffs, options, unary=None, comparison: bool = False): + cutoffs, options, unary=None, comparison: bool = False, eval_res: bool = False, + constraint=None): """ @brief Generate an expression @@ -273,13 +278,13 @@ class AstGenerator: # Check the expression depth against settings if self.current_nesting_depth > self.settings['generation-options']['max-nesting-depth'] or random.random() < \ self.settings['block-termination-probability']: - self.generate_literal(random.choice(expr_type)) + self.generate_literal(random.choice(expr_type), constraint) self.current_nesting_depth -= 1 return # Generate op = _choose_option(cutoffs, number_line, options) - self._generate_expr(comparison, expr_type, op, unary) + self._generate_expr(comparison, expr_type, op, unary, eval_res, constraint) # Return to parent self.current_nesting_depth -= 1 @@ -310,7 +315,7 @@ class AstGenerator: # Return to parent self.current_ast_element = parent - def generate_binary(self, op, op_type): + def generate_binary(self, op, op_type, eval_res=None, constraint=None): """ @brief Generate a binary operation @@ -329,13 +334,19 @@ class AstGenerator: self.make_element(GAZ_OPERATOR_TAG, args) # Gnereate lhs and rhs - self.generate_xhs(GAZ_LHS_TAG, op_type) - self.generate_xhs(GAZ_RHS_TAG, op_type) + self.generate_xhs(GAZ_LHS_TAG, op_type, constraint) + + self.py_unparser = TinyPyUnparser(self.current_ast_element.find(GAZ_LHS_TAG), True) + self.py_unparser.unparse() + print(self.py_unparser.source) + + + self.generate_xhs(GAZ_RHS_TAG, op_type, constraint=(op, eval(self.py_unparser.source))) # Return to parent self.current_ast_element = parent - def generate_bracket(self, op_type): + def generate_bracket(self, op_type, constraint=None): """ @brief Generate a bracket operation @@ -347,12 +358,12 @@ class AstGenerator: self.make_element(GAZ_BRACKET_TAG, args) # Generate the expression in the brackets - self.generate_xhs(GAZ_RHS_TAG, op_type) + self.generate_xhs(GAZ_RHS_TAG, op_type, constraint) # Return to parent self.current_ast_element = parent - def generate_xhs(self, handedness, op_type, is_zero=False): + def generate_xhs(self, handedness, op_type, is_zero=False, constraint=None): """ @brief generate a lhs or a rhs depending on handedness @@ -364,11 +375,13 @@ class AstGenerator: self.make_element(handedness, []) - self.generate_expression(op_type, is_zero=is_zero) + element = self.generate_expression(op_type, is_zero=is_zero, constraint=constraint) self.current_ast_element = parent - def generate_unary(self, op, op_type=ANY_TYPE): + return element + + def generate_unary(self, op, op_type=ANY_TYPE, constraint=None): """ @brief Generate a unary operation @@ -381,7 +394,7 @@ class AstGenerator: ] self.make_element(GAZ_UNARY_OPERATOR_TAG, args) - self.generate_xhs(GAZ_RHS_TAG, op_type) + self.generate_xhs(GAZ_RHS_TAG, op_type, constraint) self.current_ast_element = parent @@ -522,16 +535,17 @@ class AstGenerator: else: return Variable(self.get_name(GAZ_VAR_TAG), var_type, mut) - def generate_literal(self, var_type: str, value=None): + def generate_literal(self, var_type: str, value=None, constraint: tuple[str, str] | None = None): """ @brief generate a literal @param var_type: Type of the literal @param value: optional value of the literal + @param constraint: optional constraint @return: None """ if value is None: - value = self.get_value(var_type) + value = self.get_value(var_type, constraint) else: value = value @@ -567,7 +581,7 @@ class AstGenerator: self.current_scope = current_scope self.current_ast_element = current_element - def generate_expression(self, expr_type: str, is_zero=False): + def generate_expression(self, expr_type: str, is_zero=False, constraint=None): """ @brief generate an expression @@ -579,16 +593,16 @@ class AstGenerator: self.generate_literal(expr_type, value=0) return elif expr_type == GAZ_INT_KEY or expr_type == GAZ_FLOAT_KEY: - self.generate_int_expr() + self.generate_int_expr(constraint) elif expr_type == GAZ_BOOL_KEY: if random.random() < 0.5: - self.generate_bool_expr() + self.generate_bool_expr(constraint) else: - self.generate_comp_expr() + self.generate_comp_expr(constraint) elif expr_type == GAZ_CHAR_KEY: - self.generate_char_expr() + self.generate_char_expr(constraint) elif expr_type == GAZ_FLOAT_KEY: - self.generate_float_expr() + self.generate_float_expr(constraint) elif expr_type == ANY_TYPE: # TODO implement the choice of any type ty = self.get_type(GAZ_RHS_TAG) self.generate_expression(ty) @@ -613,39 +627,39 @@ class AstGenerator: def generate_arg(self): return Argument(self.get_name(GAZ_VAR_TAG), self.get_type(GAZ_VAR_TAG)) - def generate_int_expr(self): + def generate_int_expr(self, constraint=None): self._generate_expression([GAZ_INT_KEY], self.int_op_numline, self.int_op_cutoffs, self.int_op_options, self.int_unary) - def generate_float_expr(self): + def generate_float_expr(self, constraint=None): self._generate_expression([GAZ_FLOAT_KEY, GAZ_INT_KEY], self.float_op_numline, self.float_op_cutoffs, self.float_op_options, self.float_unary) - def generate_bool_expr(self): + def generate_bool_expr(self, constraint=None): self._generate_expression([GAZ_BOOL_KEY], self.bool_op_numline, self.bool_op_cutoffs, self.bool_op_options, self.bool_unary) - def generate_char_expr(self): + def generate_char_expr(self, constraint=None): self._generate_expression([GAZ_CHAR_KEY], self.char_op_numline, self.char_op_cutoffs, self.char_op_options) - def generate_comp_expr(self): + def generate_comp_expr(self, constraint=None): self._generate_expression([GAZ_BOOL_KEY], self.comp_op_numline, self.comp_op_cutoffs, self.comp_op_options, - comparison=True) #, evals=self.get_truth()) + comparison=True, eval_res=self.get_truth()) def push_scope(self, xml_element: ET.Element = None): scope = Scope(self.current_scope) @@ -689,7 +703,7 @@ class AstGenerator: if res < cutoffs[i]: return ops[i] # TODO everything should be fast faied - def get_value(self, type): + def get_value(self, type, constraint: tuple[str, str] | None = None): if type == GAZ_INT_KEY: if self.settings["properties"]["generate-max-int"]: return random.randint(-2147483648, 2147483647) @@ -889,16 +903,16 @@ class AstGenerator: break break - def _generate_expr(self, comparison, expr_type, op, unary): + def _generate_expr(self, comparison, expr_type, op, unary, eval_res=None, constraint=None): if op in unary: - self.generate_unary(op, random.choice(expr_type)) + self.generate_unary(op, random.choice(expr_type), constraint) elif op == GAZ_BRACKET_TAG: - self.generate_bracket(random.choice(expr_type)) + self.generate_bracket(random.choice(expr_type), constraint) elif comparison: if op in ['equality', 'inequality']: - self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY, GAZ_CHAR_KEY])) + self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY, GAZ_CHAR_KEY]), eval_res, constraint) else: - self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY])) + self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY]), eval_res, constraint) else: self.generate_binary(op, random.choice(expr_type)) @@ -923,3 +937,6 @@ class AstGenerator: self.current_ast_element.append(var.xml) self.current_ast_element = parent + + def get_truth(self): + return random.random() < self.settings['misc-weights']['conditional-true'] diff --git a/ast_generator/tiny_py_unparser.py b/ast_generator/tiny_py_unparser.py new file mode 100644 index 0000000..2bd2f33 --- /dev/null +++ b/ast_generator/tiny_py_unparser.py @@ -0,0 +1,66 @@ +from ast_parser.general_unparser import GeneralUnparser +from ast_parser.python_unparser import PythonUnparser, to_python_type, to_python_op +from constants import GAZ_TY_KEY, GAZ_TRUE_BLOCK_TAG, GAZ_FALSE_BLOCK_TAG +import xml.etree.ElementTree as ET + + +class TinyPyUnparser(GeneralUnparser): + def __init__(self, ast: ET.Element, debug=False): + super().__init__(ast, debug, + endline='\n', + outstream_begin_delimiter="gprint(", + outstream_end_delimiter=", end='')", + function_return_type_indicator_predicate="->", + loop_start_delimiter="while ", + loop_end_delimiter=":", + conditional_case_delimiter="elif ", + conditional_start_delimiter="if ", + conditional_else_delimiter="else:", + conditional_end_delimiter=":", + block_start_delimiter="", + block_end_delimiter="", # TODO can this contain the pass? + strip_conditionals=True) + + def format_variable(self, mut, ty, name, declaration: bool = False): + if declaration: + return "{}: {}".format(name, ty) + else: + return "{}".format(name) + + def translate_value(self, val): + return str(val) + + def translate_op(self, param, ty=None): + return to_python_op(param, ty) + + def translate_type(self, ty): + return to_python_type(ty) + + def function_declaration(self, xml_tag, args, name, return_type): + return "def {}{} {}:".format( + name, + args, + return_type, + ) + + def format_single_arg(self, ty, name): + return "{}: {}".format(name, ty) + + def unparse_block(self, node): + # super().unparse_block(node) + self.source += f"{self.block_delimiters[0]}\n" + self.indentation += 4 + for child in node: + self.unparse_node(child) + self.source += self.indentation_character * self.indentation + "pass\n" + self.indentation -= 4 + if node.get(GAZ_TY_KEY) is None: + self.source += f"{self.block_delimiters[1]}\n\n" + elif node.get(GAZ_TY_KEY) in [GAZ_TRUE_BLOCK_TAG, GAZ_FALSE_BLOCK_TAG]: + self.source += f"{self.block_delimiters[1]}" + + def unparse(self): + super().unparse() + + def setup(self): + pass \ No newline at end of file diff --git a/config.yaml b/config.yaml index 5ad7b7a..351c4f2 100644 --- a/config.yaml +++ b/config.yaml @@ -90,9 +90,7 @@ misc-weights: type-qualifier-weights: const: 10 var: 60 - conditional-eval: - true: 50 - false: 50 + conditional-true: 0.9 # probability for a conditional to be true block-termination-probability: 0.2 # probability for a block to terminate