Merge pull request 'Conditional Truth' (#5) from ayrton/conditional-eval into main
Reviewed-on: #5
This commit is contained in:
commit
0bd4ed0606
3 changed files with 117 additions and 36 deletions
|
@ -3,8 +3,10 @@ import warnings
|
||||||
|
|
||||||
from english_words import get_english_words_set
|
from english_words import get_english_words_set
|
||||||
|
|
||||||
|
from ast_generator.tiny_py_unparser import TinyPyUnparser
|
||||||
from ast_generator.utils import *
|
from ast_generator.utils import *
|
||||||
from ast_generator.utils import filter_options, _choose_option
|
from ast_generator.utils import filter_options, _choose_option
|
||||||
|
from ast_parser.python_unparser import PythonUnparser
|
||||||
from constants import *
|
from constants import *
|
||||||
|
|
||||||
import keyword
|
import keyword
|
||||||
|
@ -49,6 +51,8 @@ class AstGenerator:
|
||||||
self.current_nesting_depth = 0
|
self.current_nesting_depth = 0
|
||||||
self.current_control_flow_nesting_depth = 0
|
self.current_control_flow_nesting_depth = 0
|
||||||
|
|
||||||
|
self.py_unparser = None
|
||||||
|
|
||||||
self._init_numlines()
|
self._init_numlines()
|
||||||
|
|
||||||
def _init_numlines(self):
|
def _init_numlines(self):
|
||||||
|
@ -64,7 +68,7 @@ class AstGenerator:
|
||||||
settings=self.settings))
|
settings=self.settings))
|
||||||
self.bool_unary = ['not']
|
self.bool_unary = ['not']
|
||||||
self.float_op_options, self.float_op_cutoffs, self.float_op_numline = (
|
self.float_op_options, self.float_op_cutoffs, self.float_op_numline = (
|
||||||
get_numberlines('expression-weights', ['brackets', 'arithmetic', 'unary'], [[], [], ['not']],
|
get_numberlines('expression-weights', ['brackets', 'arithmetic', 'unary'], [[], ['modulo'], ['not']],
|
||||||
self.settings))
|
self.settings))
|
||||||
self.float_unary = ['negation', 'noop']
|
self.float_unary = ['negation', 'noop']
|
||||||
self.char_op_options, self.char_op_cutoffs, self.char_op_numline = (
|
self.char_op_options, self.char_op_cutoffs, self.char_op_numline = (
|
||||||
|
@ -254,7 +258,8 @@ class AstGenerator:
|
||||||
self._generate_from_options(cutoffs, number_line, options)
|
self._generate_from_options(cutoffs, number_line, options)
|
||||||
|
|
||||||
def _generate_expression(self, expr_type: list[str], number_line,
|
def _generate_expression(self, expr_type: list[str], number_line,
|
||||||
cutoffs, options, unary=None, comparison: bool = False):
|
cutoffs, options, unary=None, comparison: bool = False, eval_res: bool = False,
|
||||||
|
constraint=None):
|
||||||
"""
|
"""
|
||||||
@brief Generate an expression
|
@brief Generate an expression
|
||||||
|
|
||||||
|
@ -273,13 +278,13 @@ class AstGenerator:
|
||||||
# Check the expression depth against settings
|
# Check the expression depth against settings
|
||||||
if self.current_nesting_depth > self.settings['generation-options']['max-nesting-depth'] or random.random() < \
|
if self.current_nesting_depth > self.settings['generation-options']['max-nesting-depth'] or random.random() < \
|
||||||
self.settings['block-termination-probability']:
|
self.settings['block-termination-probability']:
|
||||||
self.generate_literal(random.choice(expr_type))
|
self.generate_literal(random.choice(expr_type), constraint)
|
||||||
self.current_nesting_depth -= 1
|
self.current_nesting_depth -= 1
|
||||||
return
|
return
|
||||||
|
|
||||||
# Generate
|
# Generate
|
||||||
op = _choose_option(cutoffs, number_line, options)
|
op = _choose_option(cutoffs, number_line, options)
|
||||||
self._generate_expr(comparison, expr_type, op, unary)
|
self._generate_expr(comparison, expr_type, op, unary, eval_res, constraint)
|
||||||
|
|
||||||
# Return to parent
|
# Return to parent
|
||||||
self.current_nesting_depth -= 1
|
self.current_nesting_depth -= 1
|
||||||
|
@ -310,7 +315,7 @@ class AstGenerator:
|
||||||
# Return to parent
|
# Return to parent
|
||||||
self.current_ast_element = parent
|
self.current_ast_element = parent
|
||||||
|
|
||||||
def generate_binary(self, op, op_type):
|
def generate_binary(self, op, op_type, eval_res=None, constraint=None):
|
||||||
"""
|
"""
|
||||||
@brief Generate a binary operation
|
@brief Generate a binary operation
|
||||||
|
|
||||||
|
@ -329,13 +334,19 @@ class AstGenerator:
|
||||||
self.make_element(GAZ_OPERATOR_TAG, args)
|
self.make_element(GAZ_OPERATOR_TAG, args)
|
||||||
|
|
||||||
# Gnereate lhs and rhs
|
# Gnereate lhs and rhs
|
||||||
self.generate_xhs(GAZ_LHS_TAG, op_type)
|
self.generate_xhs(GAZ_LHS_TAG, op_type, constraint)
|
||||||
self.generate_xhs(GAZ_RHS_TAG, op_type)
|
|
||||||
|
self.py_unparser = TinyPyUnparser(self.current_ast_element.find(GAZ_LHS_TAG), True)
|
||||||
|
self.py_unparser.unparse()
|
||||||
|
print(self.py_unparser.source)
|
||||||
|
|
||||||
|
|
||||||
|
self.generate_xhs(GAZ_RHS_TAG, op_type, constraint=(op, eval(self.py_unparser.source)))
|
||||||
|
|
||||||
# Return to parent
|
# Return to parent
|
||||||
self.current_ast_element = parent
|
self.current_ast_element = parent
|
||||||
|
|
||||||
def generate_bracket(self, op_type):
|
def generate_bracket(self, op_type, constraint=None):
|
||||||
"""
|
"""
|
||||||
@brief Generate a bracket operation
|
@brief Generate a bracket operation
|
||||||
|
|
||||||
|
@ -347,12 +358,12 @@ class AstGenerator:
|
||||||
self.make_element(GAZ_BRACKET_TAG, args)
|
self.make_element(GAZ_BRACKET_TAG, args)
|
||||||
|
|
||||||
# Generate the expression in the brackets
|
# Generate the expression in the brackets
|
||||||
self.generate_xhs(GAZ_RHS_TAG, op_type)
|
self.generate_xhs(GAZ_RHS_TAG, op_type, constraint)
|
||||||
|
|
||||||
# Return to parent
|
# Return to parent
|
||||||
self.current_ast_element = parent
|
self.current_ast_element = parent
|
||||||
|
|
||||||
def generate_xhs(self, handedness, op_type, is_zero=False):
|
def generate_xhs(self, handedness, op_type, is_zero=False, constraint=None):
|
||||||
"""
|
"""
|
||||||
@brief generate a lhs or a rhs depending on handedness
|
@brief generate a lhs or a rhs depending on handedness
|
||||||
|
|
||||||
|
@ -364,11 +375,13 @@ class AstGenerator:
|
||||||
|
|
||||||
self.make_element(handedness, [])
|
self.make_element(handedness, [])
|
||||||
|
|
||||||
self.generate_expression(op_type, is_zero=is_zero)
|
element = self.generate_expression(op_type, is_zero=is_zero, constraint=constraint)
|
||||||
|
|
||||||
self.current_ast_element = parent
|
self.current_ast_element = parent
|
||||||
|
|
||||||
def generate_unary(self, op, op_type=ANY_TYPE):
|
return element
|
||||||
|
|
||||||
|
def generate_unary(self, op, op_type=ANY_TYPE, constraint=None):
|
||||||
"""
|
"""
|
||||||
@brief Generate a unary operation
|
@brief Generate a unary operation
|
||||||
|
|
||||||
|
@ -381,7 +394,7 @@ class AstGenerator:
|
||||||
]
|
]
|
||||||
self.make_element(GAZ_UNARY_OPERATOR_TAG, args)
|
self.make_element(GAZ_UNARY_OPERATOR_TAG, args)
|
||||||
|
|
||||||
self.generate_xhs(GAZ_RHS_TAG, op_type)
|
self.generate_xhs(GAZ_RHS_TAG, op_type, constraint)
|
||||||
|
|
||||||
self.current_ast_element = parent
|
self.current_ast_element = parent
|
||||||
|
|
||||||
|
@ -522,16 +535,17 @@ class AstGenerator:
|
||||||
else:
|
else:
|
||||||
return Variable(self.get_name(GAZ_VAR_TAG), var_type, mut)
|
return Variable(self.get_name(GAZ_VAR_TAG), var_type, mut)
|
||||||
|
|
||||||
def generate_literal(self, var_type: str, value=None):
|
def generate_literal(self, var_type: str, value=None, constraint: tuple[str, str] | None = None):
|
||||||
"""
|
"""
|
||||||
@brief generate a literal
|
@brief generate a literal
|
||||||
|
|
||||||
@param var_type: Type of the literal
|
@param var_type: Type of the literal
|
||||||
@param value: optional value of the literal
|
@param value: optional value of the literal
|
||||||
|
@param constraint: optional constraint
|
||||||
@return: None
|
@return: None
|
||||||
"""
|
"""
|
||||||
if value is None:
|
if value is None:
|
||||||
value = self.get_value(var_type)
|
value = self.get_value(var_type, constraint)
|
||||||
else:
|
else:
|
||||||
value = value
|
value = value
|
||||||
|
|
||||||
|
@ -567,7 +581,7 @@ class AstGenerator:
|
||||||
self.current_scope = current_scope
|
self.current_scope = current_scope
|
||||||
self.current_ast_element = current_element
|
self.current_ast_element = current_element
|
||||||
|
|
||||||
def generate_expression(self, expr_type: str, is_zero=False):
|
def generate_expression(self, expr_type: str, is_zero=False, constraint=None):
|
||||||
"""
|
"""
|
||||||
@brief generate an expression
|
@brief generate an expression
|
||||||
|
|
||||||
|
@ -579,16 +593,16 @@ class AstGenerator:
|
||||||
self.generate_literal(expr_type, value=0)
|
self.generate_literal(expr_type, value=0)
|
||||||
return
|
return
|
||||||
elif expr_type == GAZ_INT_KEY or expr_type == GAZ_FLOAT_KEY:
|
elif expr_type == GAZ_INT_KEY or expr_type == GAZ_FLOAT_KEY:
|
||||||
self.generate_int_expr()
|
self.generate_int_expr(constraint)
|
||||||
elif expr_type == GAZ_BOOL_KEY:
|
elif expr_type == GAZ_BOOL_KEY:
|
||||||
if random.random() < 0.5:
|
if random.random() < 0.5:
|
||||||
self.generate_bool_expr()
|
self.generate_bool_expr(constraint)
|
||||||
else:
|
else:
|
||||||
self.generate_comp_expr()
|
self.generate_comp_expr(constraint)
|
||||||
elif expr_type == GAZ_CHAR_KEY:
|
elif expr_type == GAZ_CHAR_KEY:
|
||||||
self.generate_char_expr()
|
self.generate_char_expr(constraint)
|
||||||
elif expr_type == GAZ_FLOAT_KEY:
|
elif expr_type == GAZ_FLOAT_KEY:
|
||||||
self.generate_float_expr()
|
self.generate_float_expr(constraint)
|
||||||
elif expr_type == ANY_TYPE: # TODO implement the choice of any type
|
elif expr_type == ANY_TYPE: # TODO implement the choice of any type
|
||||||
ty = self.get_type(GAZ_RHS_TAG)
|
ty = self.get_type(GAZ_RHS_TAG)
|
||||||
self.generate_expression(ty)
|
self.generate_expression(ty)
|
||||||
|
@ -613,39 +627,39 @@ class AstGenerator:
|
||||||
def generate_arg(self):
|
def generate_arg(self):
|
||||||
return Argument(self.get_name(GAZ_VAR_TAG), self.get_type(GAZ_VAR_TAG))
|
return Argument(self.get_name(GAZ_VAR_TAG), self.get_type(GAZ_VAR_TAG))
|
||||||
|
|
||||||
def generate_int_expr(self):
|
def generate_int_expr(self, constraint=None):
|
||||||
self._generate_expression([GAZ_INT_KEY],
|
self._generate_expression([GAZ_INT_KEY],
|
||||||
self.int_op_numline,
|
self.int_op_numline,
|
||||||
self.int_op_cutoffs,
|
self.int_op_cutoffs,
|
||||||
self.int_op_options,
|
self.int_op_options,
|
||||||
self.int_unary)
|
self.int_unary)
|
||||||
|
|
||||||
def generate_float_expr(self):
|
def generate_float_expr(self, constraint=None):
|
||||||
self._generate_expression([GAZ_FLOAT_KEY, GAZ_INT_KEY],
|
self._generate_expression([GAZ_FLOAT_KEY, GAZ_INT_KEY],
|
||||||
self.float_op_numline,
|
self.float_op_numline,
|
||||||
self.float_op_cutoffs,
|
self.float_op_cutoffs,
|
||||||
self.float_op_options,
|
self.float_op_options,
|
||||||
self.float_unary)
|
self.float_unary)
|
||||||
|
|
||||||
def generate_bool_expr(self):
|
def generate_bool_expr(self, constraint=None):
|
||||||
self._generate_expression([GAZ_BOOL_KEY],
|
self._generate_expression([GAZ_BOOL_KEY],
|
||||||
self.bool_op_numline,
|
self.bool_op_numline,
|
||||||
self.bool_op_cutoffs,
|
self.bool_op_cutoffs,
|
||||||
self.bool_op_options,
|
self.bool_op_options,
|
||||||
self.bool_unary)
|
self.bool_unary)
|
||||||
|
|
||||||
def generate_char_expr(self):
|
def generate_char_expr(self, constraint=None):
|
||||||
self._generate_expression([GAZ_CHAR_KEY],
|
self._generate_expression([GAZ_CHAR_KEY],
|
||||||
self.char_op_numline,
|
self.char_op_numline,
|
||||||
self.char_op_cutoffs,
|
self.char_op_cutoffs,
|
||||||
self.char_op_options)
|
self.char_op_options)
|
||||||
|
|
||||||
def generate_comp_expr(self):
|
def generate_comp_expr(self, constraint=None):
|
||||||
self._generate_expression([GAZ_BOOL_KEY],
|
self._generate_expression([GAZ_BOOL_KEY],
|
||||||
self.comp_op_numline,
|
self.comp_op_numline,
|
||||||
self.comp_op_cutoffs,
|
self.comp_op_cutoffs,
|
||||||
self.comp_op_options,
|
self.comp_op_options,
|
||||||
comparison=True) #, evals=self.get_truth())
|
comparison=True, eval_res=self.get_truth())
|
||||||
|
|
||||||
def push_scope(self, xml_element: ET.Element = None):
|
def push_scope(self, xml_element: ET.Element = None):
|
||||||
scope = Scope(self.current_scope)
|
scope = Scope(self.current_scope)
|
||||||
|
@ -689,7 +703,7 @@ class AstGenerator:
|
||||||
if res < cutoffs[i]:
|
if res < cutoffs[i]:
|
||||||
return ops[i] # TODO everything should be fast faied
|
return ops[i] # TODO everything should be fast faied
|
||||||
|
|
||||||
def get_value(self, type):
|
def get_value(self, type, constraint: tuple[str, str] | None = None):
|
||||||
if type == GAZ_INT_KEY:
|
if type == GAZ_INT_KEY:
|
||||||
if self.settings["properties"]["generate-max-int"]:
|
if self.settings["properties"]["generate-max-int"]:
|
||||||
return random.randint(-2147483648, 2147483647)
|
return random.randint(-2147483648, 2147483647)
|
||||||
|
@ -889,16 +903,16 @@ class AstGenerator:
|
||||||
break
|
break
|
||||||
break
|
break
|
||||||
|
|
||||||
def _generate_expr(self, comparison, expr_type, op, unary):
|
def _generate_expr(self, comparison, expr_type, op, unary, eval_res=None, constraint=None):
|
||||||
if op in unary:
|
if op in unary:
|
||||||
self.generate_unary(op, random.choice(expr_type))
|
self.generate_unary(op, random.choice(expr_type), constraint)
|
||||||
elif op == GAZ_BRACKET_TAG:
|
elif op == GAZ_BRACKET_TAG:
|
||||||
self.generate_bracket(random.choice(expr_type))
|
self.generate_bracket(random.choice(expr_type), constraint)
|
||||||
elif comparison:
|
elif comparison:
|
||||||
if op in ['equality', 'inequality']:
|
if op in ['equality', 'inequality']:
|
||||||
self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY, GAZ_CHAR_KEY]))
|
self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY, GAZ_CHAR_KEY]), eval_res, constraint)
|
||||||
else:
|
else:
|
||||||
self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY]))
|
self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY]), eval_res, constraint)
|
||||||
else:
|
else:
|
||||||
self.generate_binary(op, random.choice(expr_type))
|
self.generate_binary(op, random.choice(expr_type))
|
||||||
|
|
||||||
|
@ -923,3 +937,6 @@ class AstGenerator:
|
||||||
self.current_ast_element.append(var.xml)
|
self.current_ast_element.append(var.xml)
|
||||||
|
|
||||||
self.current_ast_element = parent
|
self.current_ast_element = parent
|
||||||
|
|
||||||
|
def get_truth(self):
|
||||||
|
return random.random() < self.settings['misc-weights']['conditional-true']
|
||||||
|
|
66
ast_generator/tiny_py_unparser.py
Normal file
66
ast_generator/tiny_py_unparser.py
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
from ast_parser.general_unparser import GeneralUnparser
|
||||||
|
from ast_parser.python_unparser import PythonUnparser, to_python_type, to_python_op
|
||||||
|
from constants import GAZ_TY_KEY, GAZ_TRUE_BLOCK_TAG, GAZ_FALSE_BLOCK_TAG
|
||||||
|
import xml.etree.ElementTree as ET
|
||||||
|
|
||||||
|
|
||||||
|
class TinyPyUnparser(GeneralUnparser):
|
||||||
|
def __init__(self, ast: ET.Element, debug=False):
|
||||||
|
super().__init__(ast, debug,
|
||||||
|
endline='\n',
|
||||||
|
outstream_begin_delimiter="gprint(",
|
||||||
|
outstream_end_delimiter=", end='')",
|
||||||
|
function_return_type_indicator_predicate="->",
|
||||||
|
loop_start_delimiter="while ",
|
||||||
|
loop_end_delimiter=":",
|
||||||
|
conditional_case_delimiter="elif ",
|
||||||
|
conditional_start_delimiter="if ",
|
||||||
|
conditional_else_delimiter="else:",
|
||||||
|
conditional_end_delimiter=":",
|
||||||
|
block_start_delimiter="",
|
||||||
|
block_end_delimiter="", # TODO can this contain the pass?
|
||||||
|
strip_conditionals=True)
|
||||||
|
|
||||||
|
def format_variable(self, mut, ty, name, declaration: bool = False):
|
||||||
|
if declaration:
|
||||||
|
return "{}: {}".format(name, ty)
|
||||||
|
else:
|
||||||
|
return "{}".format(name)
|
||||||
|
|
||||||
|
def translate_value(self, val):
|
||||||
|
return str(val)
|
||||||
|
|
||||||
|
def translate_op(self, param, ty=None):
|
||||||
|
return to_python_op(param, ty)
|
||||||
|
|
||||||
|
def translate_type(self, ty):
|
||||||
|
return to_python_type(ty)
|
||||||
|
|
||||||
|
def function_declaration(self, xml_tag, args, name, return_type):
|
||||||
|
return "def {}{} {}:".format(
|
||||||
|
name,
|
||||||
|
args,
|
||||||
|
return_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_single_arg(self, ty, name):
|
||||||
|
return "{}: {}".format(name, ty)
|
||||||
|
|
||||||
|
def unparse_block(self, node):
|
||||||
|
# super().unparse_block(node)
|
||||||
|
self.source += f"{self.block_delimiters[0]}\n"
|
||||||
|
self.indentation += 4
|
||||||
|
for child in node:
|
||||||
|
self.unparse_node(child)
|
||||||
|
self.source += self.indentation_character * self.indentation + "pass\n"
|
||||||
|
self.indentation -= 4
|
||||||
|
if node.get(GAZ_TY_KEY) is None:
|
||||||
|
self.source += f"{self.block_delimiters[1]}\n\n"
|
||||||
|
elif node.get(GAZ_TY_KEY) in [GAZ_TRUE_BLOCK_TAG, GAZ_FALSE_BLOCK_TAG]:
|
||||||
|
self.source += f"{self.block_delimiters[1]}"
|
||||||
|
|
||||||
|
def unparse(self):
|
||||||
|
super().unparse()
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
pass
|
|
@ -90,9 +90,7 @@ misc-weights:
|
||||||
type-qualifier-weights:
|
type-qualifier-weights:
|
||||||
const: 10
|
const: 10
|
||||||
var: 60
|
var: 60
|
||||||
conditional-eval:
|
conditional-true: 0.9 # probability for a conditional to be true
|
||||||
true: 50
|
|
||||||
false: 50
|
|
||||||
|
|
||||||
block-termination-probability: 0.2 # probability for a block to terminate
|
block-termination-probability: 0.2 # probability for a block to terminate
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue