diff --git a/ast_generator/ast_generator.py b/ast_generator/ast_generator.py index 19b8e7d..6067b57 100644 --- a/ast_generator/ast_generator.py +++ b/ast_generator/ast_generator.py @@ -5,6 +5,7 @@ from english_words import get_english_words_set from ast_generator.utils import * from ast_generator.utils import filter_options, _choose_option +from ast_parser.python_unparser import PythonUnparser from constants import * import keyword @@ -49,6 +50,8 @@ class AstGenerator: self.current_nesting_depth = 0 self.current_control_flow_nesting_depth = 0 + self.py_unparser = None + self._init_numlines() def _init_numlines(self): @@ -254,7 +257,8 @@ class AstGenerator: self._generate_from_options(cutoffs, number_line, options) def _generate_expression(self, expr_type: list[str], number_line, - cutoffs, options, unary=None, comparison: bool = False): + cutoffs, options, unary=None, comparison: bool = False, eval_res: bool = False, + constraint=None): """ @brief Generate an expression @@ -273,13 +277,13 @@ class AstGenerator: # Check the expression depth against settings if self.current_nesting_depth > self.settings['generation-options']['max-nesting-depth'] or random.random() < \ self.settings['block-termination-probability']: - self.generate_literal(random.choice(expr_type)) + self.generate_literal(random.choice(expr_type), constraint) self.current_nesting_depth -= 1 return # Generate op = _choose_option(cutoffs, number_line, options) - self._generate_expr(comparison, expr_type, op, unary) + self._generate_expr(comparison, expr_type, op, unary, eval_res, constraint) # Return to parent self.current_nesting_depth -= 1 @@ -310,7 +314,7 @@ class AstGenerator: # Return to parent self.current_ast_element = parent - def generate_binary(self, op, op_type): + def generate_binary(self, op, op_type, eval_res=None, constraint=None): """ @brief Generate a binary operation @@ -329,13 +333,17 @@ class AstGenerator: self.make_element(GAZ_OPERATOR_TAG, args) # Gnereate lhs and rhs - self.generate_xhs(GAZ_LHS_TAG, op_type) - self.generate_xhs(GAZ_RHS_TAG, op_type) + self.generate_xhs(GAZ_LHS_TAG, op_type, constraint) + + self.py_unparser = PythonUnparser(self.current_ast_element.find(GAZ_LHS_TAG), True) + self.py_unparser.unparse() + + self.generate_xhs(GAZ_RHS_TAG, op_type, constraint=(op, eval(self.py_unparser.source))) # Return to parent self.current_ast_element = parent - def generate_bracket(self, op_type): + def generate_bracket(self, op_type, constraint=None): """ @brief Generate a bracket operation @@ -347,12 +355,12 @@ class AstGenerator: self.make_element(GAZ_BRACKET_TAG, args) # Generate the expression in the brackets - self.generate_xhs(GAZ_RHS_TAG, op_type) + self.generate_xhs(GAZ_RHS_TAG, op_type, constraint) # Return to parent self.current_ast_element = parent - def generate_xhs(self, handedness, op_type, is_zero=False): + def generate_xhs(self, handedness, op_type, is_zero=False, constraint=None): """ @brief generate a lhs or a rhs depending on handedness @@ -364,10 +372,12 @@ class AstGenerator: self.make_element(handedness, []) - self.generate_expression(op_type, is_zero=is_zero) + element = self.generate_expression(op_type, is_zero=is_zero, constraint=constraint) self.current_ast_element = parent + return element + def generate_unary(self, op, op_type=ANY_TYPE): """ @brief Generate a unary operation @@ -522,16 +532,17 @@ class AstGenerator: else: return Variable(self.get_name(GAZ_VAR_TAG), var_type, mut) - def generate_literal(self, var_type: str, value=None): + def generate_literal(self, var_type: str, value=None, constraint: tuple[str, str] | None = None): """ @brief generate a literal @param var_type: Type of the literal @param value: optional value of the literal + @param constraint: optional constraint @return: None """ if value is None: - value = self.get_value(var_type) + value = self.get_value(var_type, constraint) else: value = value @@ -567,7 +578,7 @@ class AstGenerator: self.current_scope = current_scope self.current_ast_element = current_element - def generate_expression(self, expr_type: str, is_zero=False): + def generate_expression(self, expr_type: str, is_zero=False, constraint=None): """ @brief generate an expression @@ -579,16 +590,16 @@ class AstGenerator: self.generate_literal(expr_type, value=0) return elif expr_type == GAZ_INT_KEY or expr_type == GAZ_FLOAT_KEY: - self.generate_int_expr() + self.generate_int_expr(constraint) elif expr_type == GAZ_BOOL_KEY: if random.random() < 0.5: - self.generate_bool_expr() + self.generate_bool_expr(constraint) else: - self.generate_comp_expr() + self.generate_comp_expr(constraint) elif expr_type == GAZ_CHAR_KEY: - self.generate_char_expr() + self.generate_char_expr(constraint) elif expr_type == GAZ_FLOAT_KEY: - self.generate_float_expr() + self.generate_float_expr(constraint) elif expr_type == ANY_TYPE: # TODO implement the choice of any type ty = self.get_type(GAZ_RHS_TAG) self.generate_expression(ty) @@ -613,39 +624,39 @@ class AstGenerator: def generate_arg(self): return Argument(self.get_name(GAZ_VAR_TAG), self.get_type(GAZ_VAR_TAG)) - def generate_int_expr(self): + def generate_int_expr(self, constraint=None): self._generate_expression([GAZ_INT_KEY], self.int_op_numline, self.int_op_cutoffs, self.int_op_options, self.int_unary) - def generate_float_expr(self): + def generate_float_expr(self, constraint=None): self._generate_expression([GAZ_FLOAT_KEY, GAZ_INT_KEY], self.float_op_numline, self.float_op_cutoffs, self.float_op_options, self.float_unary) - def generate_bool_expr(self): + def generate_bool_expr(self, constraint=None): self._generate_expression([GAZ_BOOL_KEY], self.bool_op_numline, self.bool_op_cutoffs, self.bool_op_options, self.bool_unary) - def generate_char_expr(self): + def generate_char_expr(self, constraint=None): self._generate_expression([GAZ_CHAR_KEY], self.char_op_numline, self.char_op_cutoffs, self.char_op_options) - def generate_comp_expr(self): + def generate_comp_expr(self, constraint=None): self._generate_expression([GAZ_BOOL_KEY], self.comp_op_numline, self.comp_op_cutoffs, self.comp_op_options, - comparison=True) #, evals=self.get_truth()) + comparison=True, eval_res=self.get_truth()) def push_scope(self, xml_element: ET.Element = None): scope = Scope(self.current_scope) @@ -689,7 +700,7 @@ class AstGenerator: if res < cutoffs[i]: return ops[i] # TODO everything should be fast faied - def get_value(self, type): + def get_value(self, type, constraint: tuple[str, str] | None = None): if type == GAZ_INT_KEY: if self.settings["properties"]["generate-max-int"]: return random.randint(-2147483648, 2147483647) @@ -889,16 +900,16 @@ class AstGenerator: break break - def _generate_expr(self, comparison, expr_type, op, unary): + def _generate_expr(self, comparison, expr_type, op, unary, eval_res=None, constraint=None): if op in unary: - self.generate_unary(op, random.choice(expr_type)) + self.generate_unary(op, random.choice(expr_type), constraint) elif op == GAZ_BRACKET_TAG: - self.generate_bracket(random.choice(expr_type)) + self.generate_bracket(random.choice(expr_type), constraint) elif comparison: if op in ['equality', 'inequality']: - self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY, GAZ_CHAR_KEY])) + self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY, GAZ_CHAR_KEY]), eval_res, constraint) else: - self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY])) + self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY]), eval_res, constraint) else: self.generate_binary(op, random.choice(expr_type)) @@ -923,3 +934,6 @@ class AstGenerator: self.current_ast_element.append(var.xml) self.current_ast_element = parent + + def get_truth(self): + return random.random() < self.settings['misc-weights']['conditional-true'] diff --git a/config.yaml b/config.yaml index 5ad7b7a..351c4f2 100644 --- a/config.yaml +++ b/config.yaml @@ -90,9 +90,7 @@ misc-weights: type-qualifier-weights: const: 10 var: 60 - conditional-eval: - true: 50 - false: 50 + conditional-true: 0.9 # probability for a conditional to be true block-termination-probability: 0.2 # probability for a block to terminate