Added Brackets

Took 1 hour 13 minutes
This commit is contained in:
ayrton 2023-11-21 20:40:50 -07:00
parent 1e22f5a968
commit 561a9a5efa
4 changed files with 86 additions and 8 deletions

View file

@ -51,6 +51,57 @@ class AstGenerator:
self.current_nesting_depth = 0 self.current_nesting_depth = 0
self.current_control_flow_nesting_depth = 0 self.current_control_flow_nesting_depth = 0
# Numberlines - For computing probabilities
self.int_op_options, self.int_op_cutoffs, self.int_op_numline = (
self.get_numberlines('expression-weights',
['brackets', 'arithmetic', 'unary'],
[[], [], ['negation']]))
self.int_unary = ['negation', 'noop']
pass
def get_numberlines(self, settings_section: str, subsettings: list[str], excluded_values: list[list[str or None]]):
assert len(subsettings) == len(excluded_values)
number_line = 0
cutoffs = []
cutoff = 0
options = {}
option = 0
settings = []
for key, value in self.settings[settings_section].items():
if key in subsettings and key not in excluded_values: # this check needs to be done recursively
if isinstance(value, int):
t = {
key: value
}
settings.append(t)
elif isinstance(value, dict):
settings.append(value)
else:
raise TypeError("invalid setting type. Found " + str(value) + " instead of expected int or dict")
for v in settings:
if isinstance(v, dict):
for key, value in v.items():
number_line += value
cutoffs.append(cutoff + value)
cutoff += value
options[option] = key
option += 1
elif isinstance(v, int):
number_line += v
cutoffs.append(cutoff + v)
cutoff += v
options[option] = v
option += 1
else:
raise TypeError("invalid setting type. Found " + str(v) + " instead of expected int")
return options, cutoffs, number_line
def generate_ast(self): def generate_ast(self):
""" """
@brief generates an AST from a grammar @brief generates an AST from a grammar
@ -214,7 +265,8 @@ class AstGenerator:
unary = ["negation", "noop"] unary = ["negation", "noop"]
self._generate_expression([GAZ_INT_KEY, GAZ_FLOAT_KEY], number_line, cutoffs, options, unary) self._generate_expression([GAZ_INT_KEY, GAZ_FLOAT_KEY], self.int_op_numline, self.int_op_cutoffs,
self.int_op_options, self.int_unary)
def generate_bool_expr(self): def generate_bool_expr(self):
# Number line # Number line
@ -259,6 +311,8 @@ class AstGenerator:
if op in unary: if op in unary:
self.generate_unary(op, random.choice(expr_type)) self.generate_unary(op, random.choice(expr_type))
elif op == GAZ_BRACKET_TAG:
self.generate_bracket(random.choice(expr_type))
else: else:
self.generate_binary(op, random.choice(expr_type)) self.generate_binary(op, random.choice(expr_type))
@ -299,6 +353,19 @@ class AstGenerator:
self.current_ast_element = parent self.current_ast_element = parent
def generate_bracket(self, op_type):
parent = self.current_ast_element
args = [
("type", op_type),
]
element = build_xml_element(args, name=GAZ_BRACKET_TAG)
self.current_ast_element.append(element)
self.current_ast_element = element
self.generate_xhs(GAZ_RHS_TAG, op_type)
self.current_ast_element = parent
def generate_xhs(self, handedness, op_type): def generate_xhs(self, handedness, op_type):
element = build_xml_element([], name=handedness) element = build_xml_element([], name=handedness)
parent = self.current_ast_element parent = self.current_ast_element
@ -667,7 +734,7 @@ def build_xml_element(*keys, name):
def get_op(op): def get_op(op):
if op == 'addition': if op == 'addition' or 'noop':
return '+' return '+'
elif op == 'subtraction': elif op == 'subtraction':
return '-' return '-'
@ -699,8 +766,6 @@ def get_op(op):
return '-' return '-'
elif op == 'not': elif op == 'not':
return 'not' return 'not'
elif op == 'noop':
return '+'
elif op == 'concatenation': elif op == 'concatenation':
return '||' return '||'
else: else:

View file

@ -72,7 +72,7 @@ class GeneralUnparser:
self.unparse_node(node) self.unparse_node(node)
def unparse_node(self, node): def unparse_node(self, node):
if node.tag not in [GAZ_VAR_TAG, GAZ_RHS_TAG, GAZ_LHS_TAG, GAZ_LIT_TAG, GAZ_OPERATOR_TAG]: if node.tag not in [GAZ_VAR_TAG, GAZ_RHS_TAG, GAZ_LHS_TAG, GAZ_LIT_TAG, GAZ_OPERATOR_TAG, GAZ_BRACKET_TAG]:
self.source += self.indentation_character * self.indentation self.source += self.indentation_character * self.indentation
if node.tag == GAZ_BLOCK_TAG: if node.tag == GAZ_BLOCK_TAG:
@ -103,15 +103,22 @@ class GeneralUnparser:
self.unparse_conditional(node) self.unparse_conditional(node)
elif node.tag == GAZ_LOOP_TAG: elif node.tag == GAZ_LOOP_TAG:
self.unparse_loop(node) self.unparse_loop(node)
elif node.tag == GAZ_BRACKET_TAG:
self.unparse_brackets(node)
else: else:
raise Exception("Unknown tag: " + node.tag) raise Exception("Unknown tag: " + node.tag)
def unparse_block(self, node): def unparse_block(self, node):
self.source += f"{self.indentation * self.indentation_character}{self.block_delimiters[0]}\n" if node.get(GAZ_TY_KEY) is None:
self.source += f"{self.indentation * self.indentation_character}"
self.source += f"{self.block_delimiters[0]}\n"
self.indentation += 4 self.indentation += 4
for child in node: for child in node:
self.unparse_node(child) self.unparse_node(child)
self.indentation -= 4 self.indentation -= 4
if node.get(GAZ_TY_KEY) is None: if node.get(GAZ_TY_KEY) is None:
self.source += f"{self.indentation * self.indentation_character}{self.block_delimiters[1]}\n\n" self.source += f"{self.indentation * self.indentation_character}{self.block_delimiters[1]}\n\n"
elif node.get(GAZ_TY_KEY) in [GAZ_TRUE_BLOCK_TAG, GAZ_FALSE_BLOCK_TAG]: elif node.get(GAZ_TY_KEY) in [GAZ_TRUE_BLOCK_TAG, GAZ_FALSE_BLOCK_TAG]:
@ -271,6 +278,11 @@ class GeneralUnparser:
self.source += "{}".format(self.translate_op(element_in.get("op"))) self.source += "{}".format(self.translate_op(element_in.get("op")))
self.unparse_xhs(element_in.find(GAZ_RHS_TAG)) self.unparse_xhs(element_in.find(GAZ_RHS_TAG))
def unparse_brackets(self, element_in: ET.Element):
self.source += "("
self.unparse_xhs(element_in.find(GAZ_RHS_TAG))
self.source += ")"
def unparse_single_arg(self, param): def unparse_single_arg(self, param):
return self.format_single_arg(self.translate_type(param.get(GAZ_TY_KEY)), param.get(GAZ_NAME_KEY)) return self.format_single_arg(self.translate_type(param.get(GAZ_TY_KEY)), param.get(GAZ_NAME_KEY))

View file

@ -24,13 +24,13 @@ def to_python_type(ty):
def to_python_op(param): def to_python_op(param):
if param == "negation" or param == "subtraction": if param == "negation" or param == "subtraction":
return "-" return "-"
elif param == "addition": elif param == "addition" or param == "noop":
return "+" return "+"
elif param == "multiplication": elif param == "multiplication":
return "*" return "*"
elif param == "division": elif param == "division":
return "/" return "/"
elif param == "modulus": elif param == "modulo":
return "%" return "%"
elif param == "power": elif param == "power":
return "**" return "**"

View file

@ -45,3 +45,4 @@ GAZ_FALSE_BLOCK_TAG = "false"
GAZ_ARG_TAG = "argument" GAZ_ARG_TAG = "argument"
GAZ_STRING_KEY = "string" GAZ_STRING_KEY = "string"
GAZ_CHAR_KEY = "char" GAZ_CHAR_KEY = "char"
GAZ_BRACKET_TAG = "brackets"