Refactored ASTGenerator #3

Merged
aCompetentBean merged 8 commits from ayrton/refactor into main 2023-11-24 07:34:47 -07:00
7 changed files with 139 additions and 151 deletions
Showing only changes of commit 7380a89082 - Show all commits

View file

@ -4,7 +4,7 @@ import xml.etree.ElementTree as ET
from english_words import get_english_words_set
from ast_generator.utils import Variable, Argument, Routine, Scope, build_xml_element
from ast_generator.utils import *
from constants import *
import keyword
@ -41,98 +41,47 @@ class AstGenerator:
self.symbol_table.append(global_scope) # NOTE for debug
self.current_scope = global_scope
names = get_english_words_set(['web2'], alpha=True)
possible_names = filter(lambda x: self.settings['properties']['id-length']['max'] <= len(x) <=
self.settings['properties']['id-length']['max'] and not keyword.iskeyword(x),
names)
var_name_list = list(possible_names)
var_name_len = len(var_name_list)
self.variable_names = var_name_list[0:var_name_len // 2]
self.routine_names = var_name_list[var_name_len // 2:var_name_len]
self._init_names()
self.ast: ET.Element or None = None
self.current_ast_element: ET.Element or None = None
self.current_nesting_depth = 0
self.current_control_flow_nesting_depth = 0
self._init_numlines()
def _init_numlines(self):
# Precompute the weighted "number lines" used to pick expression operators.
# For each expression kind (int, bool, float, char, comp) the
# (options, cutoffs, total) triple returned by get_numberlines for the
# 'expression-weights' settings section is stored on self, together with the
# unary operator names kept for int, bool and float.
# NOTE(review): this span is a rendered diff - every assignment shows the
# removed `self.get_numberlines(...)` call immediately followed by the added
# module-level `get_numberlines(..., self.settings)` call.
# Numberlines - For computing probabilities
# int: brackets, arithmetic and unary weights, with 'not' excluded
self.int_op_options, self.int_op_cutoffs, self.int_op_numline = (
self.get_numberlines('expression-weights',
['brackets', 'arithmetic', 'unary'],
[[], [], ['not']]))
get_numberlines('expression-weights', ['brackets', 'arithmetic', 'unary'], [[], [], ['not']],
self.settings))
self.int_unary = ['negation', 'noop']
# bool: brackets, comparison, logical and unary weights; the four ordering
# comparisons and the 'noop'/'negation' unary operators are excluded
self.bool_op_options, self.bool_op_cutoffs, self.bool_op_numline = (
self.get_numberlines('expression-weights',
['brackets', 'comparison', 'logical', 'unary'],
get_numberlines('expression-weights', ['brackets', 'comparison', 'logical', 'unary'],
excluded_values=[[], ['less-than-or-equal', 'greater-than-or-equal', 'less-than',
'greater-than'], [], ['noop', 'negation']]))
'greater-than'], [], ['noop', 'negation']],
settings=self.settings))
self.bool_unary = ['not']
# float: same selection and exclusions as int
self.float_op_options, self.float_op_cutoffs, self.float_op_numline = (
self.get_numberlines('expression-weights',
['brackets', 'arithmetic', 'unary'],
[[], [], ['not']]))
get_numberlines('expression-weights', ['brackets', 'arithmetic', 'unary'], [[], [], ['not']],
self.settings))
self.float_unary = ['negation', 'noop']
# char: brackets and comparison weights, ordering comparisons excluded
self.char_op_options, self.char_op_cutoffs, self.char_op_numline = (
self.get_numberlines('expression-weights',
['brackets', 'comparison'],
[[], ['less-than', 'greater-than', 'less-than-or-equal', 'greater-than-or-equal']]))
get_numberlines('expression-weights', ['brackets', 'comparison'],
[[], ['less-than', 'greater-than', 'less-than-or-equal', 'greater-than-or-equal']],
self.settings))
# comp: brackets and comparison weights, nothing excluded
self.comp_op_options, self.comp_op_cutoffs, self.comp_op_numline = (
self.get_numberlines('expression-weights',
['brackets', 'comparison'],
[[], []]))
get_numberlines('expression-weights', ['brackets', 'comparison'], [[], []], self.settings))
def get_numberlines(self, settings_section: str, subsettings: list[str], excluded_values):
    """
    @brief build the weighted number line used to pick an option at random

    Walks self.settings[settings_section] in its own key order. Every key
    listed in subsettings contributes either a single weight (int value,
    keyed by the subsetting name itself) or one weight per entry (dict
    value). Names found in any of the excluded_values lists are left out.

    self.settings is only read: earlier versions popped the excluded names
    out of the live settings dicts, so each call silently removed operators
    from every later call.

    @param settings_section: top-level key of self.settings to read
    @param subsettings: keys of that section to include
    @param excluded_values: one list of excluded option names per subsetting;
                            the lists are applied to all selected subsettings
    @return (options, cutoffs, number_line): options maps a running index to
            the option name, cutoffs holds the cumulative weight after each
            option, number_line is the total weight
    @raises TypeError if a selected setting is neither an int nor a dict
    """
    assert len(subsettings) == len(excluded_values), \
        "subsettings and excluded_values must have the same length"

    # Flatten the exclusion lists once instead of rescanning them per setting.
    excluded = {name for group in excluded_values for name in group}

    options = {}
    cutoffs = []
    number_line = 0
    for key, value in self.settings[settings_section].items():
        if key not in subsettings:
            continue
        if isinstance(value, int):
            weights = {key: value}  # a bare weight is an option named after its key
        elif isinstance(value, dict):
            weights = value
        else:
            raise TypeError("invalid setting type. Found " + str(value) + " instead of expected int or dict")
        for name, weight in weights.items():
            if name in excluded:
                continue
            number_line += weight
            cutoffs.append(number_line)
            options[len(options)] = name
    return options, cutoffs, number_line
def _init_names(self):
    """
    @brief build the pools of identifier names for variables and routines

    Takes the 'web2' english word list, keeps the alphabetic words whose
    length lies within properties.id-length (min..max, both inclusive) and
    that are not Python keywords, then gives the first half to
    self.variable_names and the second half to self.routine_names so the two
    pools never overlap.

    @effects sets self.variable_names and self.routine_names
    """
    id_length = self.settings['properties']['id-length']
    # The lower bound used to read 'max' as well, which kept only words of
    # exactly the maximum length and ignored the configured minimum.
    shortest, longest = id_length['min'], id_length['max']

    names = get_english_words_set(['web2'], alpha=True)
    # sorted(): iterating a set of strings has no stable order between runs,
    # so without it the two pools would differ from one run to the next.
    candidates = sorted(
        word for word in names
        if shortest <= len(word) <= longest and not keyword.iskeyword(word)
    )

    half = len(candidates) // 2
    self.variable_names = candidates[:half]
    self.routine_names = candidates[half:]
def generate_ast(self):
"""
@ -140,37 +89,71 @@ class AstGenerator:
"""
self.generate_top_level_block()
def generate_top_level_block(self): # TODO add constant generation into this block
i = 0
def make_element(self, name: str, keys: list[tuple[str, any]]) -> ET.Element:
    """
    @brief create an xml element for the ast and descend into it

    The new element is appended to the current ast element when there is
    one, and then becomes the current ast element itself.

    @effects modifies self.current_ast_element
    @param name: the tag for the element
    @param keys: a list of tuple containing keys for the element
    @return the newly created element
    """
    node = build_xml_element(keys, name=name)
    enclosing = self.current_ast_element
    if enclosing is not None:
        enclosing.append(node)
    self.current_ast_element = node
    return node
def make_scoped_element(self, name, keys) -> ET.Element:
    """
    @brief push a new scope, then make an xml element for the ast inside it

    @effects pushes a scope and modifies self.current_ast_element
    @param name: the tag for the element
    @param keys: a list of tuple containing keys for the element
    @return the element that was current before the call (may be None), so
            the caller can hand it back to exit_scoped_element
    """
    enclosing = self.current_ast_element  # remember where to return to
    self.push_scope()
    self.make_element(name, keys)
    return enclosing
def exit_scoped_element(self, parent):
"""
@brief leave the current element and return to parent

Counterpart of make_scoped_element: pops one scope and makes parent the
current ast element again.

@effects pops a scope and modifies self.current_ast_element
@param parent: the enclosing element to return to (the value returned by
               make_scoped_element)
"""
self.pop_scope()
# step back out so later elements are appended to the enclosing element
self.current_ast_element = parent
def generate_top_level_block(self):
"""
@brief creates the top-level block containing the whole program

Order of generation: globals, then routines, then main.

@effects sets self.ast to the new top-level block element
"""
# NOTE(review): this span is a rendered diff - the `while i < ...` line and
# the trailing `i += 1` belong to the removed loop; the two `for` loops are
# its replacement.
element = self.make_element(GAZ_BLOCK_TAG, [])
# the top-level block is the root of the generated ast
self.ast = element
# TODO generate constants and forward declarations
while i < self.settings['generation-options']['max-number-of-routines']:
# a random number of globals, from 0 up to 'max-globals' inclusive
for i in range(random.randint(0, self.settings['generation-options']['max-globals'])):
self.generate_global()
# at most 'max-number-of-routines' routines; before each one the loop may
# stop early with probability 'block-termination-probability'
for i in range(self.settings['generation-options']['max-number-of-routines']):
if random.random() < self.settings['block-termination-probability']:
break
self.generate_routine()
i += 1
# main is generated unconditionally, after every other routine
self.generate_main()
pass
def generate_main(self):
# Generate the `main` procedure: a GAZ_PROCEDURE_TAG element named "main"
# with return type GAZ_INT_KEY, whose block is requested from generate_block
# with return_stmt=True and return_value="0".
# NOTE(review): this span is a rendered diff - the manual
# push_scope / build_xml_element / append / pop_scope lines and the
# string-keyed tuples are the removed version; make_scoped_element,
# exit_scoped_element and the GAZ_*_KEY tuples are the replacement.
parent = self.current_ast_element
self.push_scope()
main_args = [ # TODO refactor these into constants
("name", "main"),
("return_type", GAZ_INT_KEY),
("args", "()"),
(GAZ_NAME_KEY, "main"),
(GAZ_RETURN_KEY, GAZ_INT_KEY),
]
element = build_xml_element(main_args, name=GAZ_PROCEDURE_TAG)
self.current_ast_element.append(element)
self.current_ast_element = element
# parent is the element to return to once main has been generated
parent = self.make_scoped_element(GAZ_PROCEDURE_TAG, main_args)
self.generate_block(return_stmt=True, return_value="0", return_type=GAZ_INT_KEY, block_type=GAZ_PROCEDURE_TAG)
self.pop_scope()
self.current_ast_element = parent
self.exit_scoped_element(parent)
def generate_block(self, tag=None, return_stmt=False, return_value=None, return_type=None, block_type=None,
loop_var=None):

View file

@ -1,53 +0,0 @@
from constants import Grammar
# Grammar table for a whole program written out as XML: each key is a
# nonterminal (in angle brackets) and each value is the list of alternative
# expansions for it. The <XML_*> helper rules at the bottom stand in for the
# literal XML punctuation so that it is not mistaken for a nonterminal.
# NOTE(review): _NAME_, _TYPE_, _ARGS_, _MODIFIER_ and _VALUE_ look like
# placeholders filled in by the consumer of this table - confirm against the
# code that expands it.
GAZPREA_TOP_LEVEL: Grammar = {
# Top level elements
'<start>': ['<topBlock>'],
'<topBlock>': ['<XML_OPEN_TAG>block<XML_CLOSE_TAG><routine_list><main_routine><routine_list><XML_OPEN_SLASH>block<XML_CLOSE_TAG>'],
# TODO constants
# Routines
'<routine>': ['<function>', '<procedure>'], # TODO forward_declaration
'<function>': [
'<XML_OPEN_TAG>function name="_NAME_" return_type="_TYPE_" args="_ARGS_"<XML_CLOSE_TAG><return_block><XML_OPEN_SLASH>function<XML_CLOSE_TAG>'],
'<procedure>': [
'<XML_OPEN_TAG>procedure name="_NAME_" return_type="_TYPE_" args="_ARGS_"<XML_CLOSE_TAG><block><XML_OPEN_SLASH>procedure<XML_CLOSE_TAG>'],
'<main_routine>': [
'<XML_OPEN_TAG>procedure name="main" return_type="int" args="()"<XML_CLOSE_TAG><return_block><XML_OPEN_SLASH>procedure<XML_CLOSE_TAG>'],
'<routine_list>': ['<routine><routine_list><routine>', '<routine>'],
# Blocks
'<block>': ['<XML_OPEN_TAG>block<XML_CLOSE_TAG><statement_list><XML_OPEN_SLASH>block<XML_CLOSE_TAG>'],
'<return_block>': ['<XML_OPEN_TAG>block<XML_CLOSE_TAG><statement_list><return><XML_OPEN_SLASH>block<XML_CLOSE_TAG>'],
'<statement>': [
'<declaration>',
'<stream>',
# '<call>',
# '<return>', # TODO if/else, loop
],
'<statement_list>': ['<statement><statement_list><statement>', '<statement>'],
# Things that belong on their own lines
'<declaration>': ['<XML_OPEN_TAG>declaration<XML_CLOSE_TAG><variable><rhs><XML_OPEN_SLASH>declaration<XML_CLOSE_TAG>'],
'<stream>': ['<out_stream>'], #, '<in_stream>'],
'<return>': ['<XML_OPEN_TAG>return<XML_CLOSE_TAG><has_value><XML_OPEN_SLASH>return<XML_CLOSE_TAG>'],
'<out_stream>': ['<XML_OPEN_TAG>stream type="std_output"<XML_CLOSE_TAG><has_value><XML_OPEN_SLASH>stream<XML_CLOSE_TAG>'],
# '<in_stream>': ['<XML_OPEN_TAG>stream type="std_input"<XML_CLOSE_TAG><has_value><XML_OPEN_SLASH>stream<XML_CLOSE_TAG>'],
# Things that are part of lines
'<has_value>': ['<variable>', '<literal>', '<operator>'],
'<lhs>': ['<XML_OPEN_TAG>lhs<XML_CLOSE_TAG><has_value><XML_OPEN_SLASH>lhs<XML_CLOSE_TAG>'],
'<rhs>': ['<XML_OPEN_TAG>rhs<XML_CLOSE_TAG><has_value><XML_OPEN_SLASH>rhs<XML_CLOSE_TAG>'],
# Things that have values
'<operator>': ['<XML_OPEN_TAG>operator<XML_CLOSE_TAG><lhs><rhs><XML_OPEN_SLASH>operator<XML_CLOSE_TAG>'],
'<variable>': ['<XML_OPEN_TAG>variable mut="_MODIFIER_" type="_TYPE_" name="_NAME_"<XML_SLASH_TAG>'],
'<literal>': ['<XML_OPEN_TAG>literal type="_TYPE_" value="_VALUE_"<XML_SLASH_TAG>'],
# Helper rules
'<XML_OPEN_TAG>': ['<'],
'<XML_CLOSE_TAG>': ['>'],
'<XML_SLASH_TAG>': ['/>'],
'<XML_OPEN_SLASH>': ['</'],
}

View file

@ -5,19 +5,21 @@ generation-options:
max-conditionals-loops: 5 # maximum number of loops/conditionals per routine
max-number-of-routines: 5 # maximum number of routines (main will always be generated)
generate-dead-code: True # generate dead code
max-loop-iterations: 100 # maximum number of iterations in a loop
max-globals: 5 # maximum number of global variables
properties:
max-range-length: 5 # maximum length of ranges, vectors and tuples, (AxA matrices can exist)
use-english-words: True # use english words instead of random names (this may limit the maximum number of names)
id-length: # length of identifiers
min: 1
max: 10
max: 5
function-name-length: # length of function names
min: 1
max: 10
number-of-arguments: # number of arguments to a routine
min: 1
max: 10
generate-max-int: True # if False, generate integers between [-1000, 1000] else
generate-max-int: False # if False, generate integers between [-1000, 1000] else
expression-weights: # weights for expressions
# the higher a weight, the more likely (0, 10000), 0 to exclude, 10000 for only that
brackets: 10

View file

@ -130,9 +130,15 @@ class TestGeneration(unittest.TestCase):
self.assertIsNotNone(self.ast_gen.current_ast_element.find("conditional"))
conditional = self.ast_gen.ast.find("conditional")
# print(ET.tostring(conditional, 'utf-8').decode('utf-8'))
print(ET.tostring(conditional, 'utf-8').decode('utf-8'))
self.assertIsNotNone(conditional.find("operator") or conditional.find("unary_operator") or conditional.find("literal"))
opts = ['operator', 'unary_operator', 'literal', 'brackets']
res = []
for i in opts:
res.append(conditional.find(i))
res_list = list(filter(lambda x: x is not None, res))
self.assertGreater(len(res_list), 0)
block = conditional.findall("block")
self.assertEqual(2, len(block))
@ -211,7 +217,7 @@ class TestGeneration(unittest.TestCase):
self.assertIsNotNone(self.ast_gen.ast)
# print(ET.tostring(self.ast_gen.ast, 'utf-8').decode('utf-8'))
print(ET.tostring(self.ast_gen.ast, 'utf-8').decode('utf-8'))
procedures = self.ast_gen.ast.findall("procedure")
self.assertLess(0, len(procedures))

View file

@ -97,3 +97,52 @@ def build_xml_element(*keys, name):
for key in list(keys)[0]: # TODO refactor
elem.set(key[0], key[1])
return elem
def get_numberlines(settings_section: str, subsettings: list[str], excluded_values, settings):
    """
    @brief build the weighted number line used to pick an option at random

    Walks settings[settings_section] in its own key order. Every key listed
    in subsettings contributes either a single weight (int value, keyed by
    the subsetting name itself) or one weight per entry (dict value). Names
    found in any of the excluded_values lists are left out.

    settings is only read: the previous version appended the live sub-dicts
    and popped the excluded names out of them, so each call permanently
    removed operators from the settings seen by every later call.

    @param settings_section: top-level key of settings to read
    @param subsettings: keys of that section to include
    @param excluded_values: one list of excluded option names per subsetting;
                            the lists are applied to all selected subsettings
    @param settings: the loaded settings mapping
    @return (options, cutoffs, number_line): options maps a running index to
            the option name, cutoffs holds the cumulative weight after each
            option, number_line is the total weight
    @raises TypeError if a selected setting is neither an int nor a dict
    """
    assert len(subsettings) == len(excluded_values), \
        "subsettings and excluded_values must have the same length"

    # Flatten the exclusion lists once instead of rescanning them per setting.
    excluded = {name for group in excluded_values for name in group}

    options = {}
    cutoffs = []
    number_line = 0
    for key, value in settings[settings_section].items():
        if key not in subsettings:
            continue
        if isinstance(value, int):
            weights = {key: value}  # a bare weight is an option named after its key
        elif isinstance(value, dict):
            weights = value
        else:
            raise TypeError("invalid setting type. Found " + str(value) + " instead of expected int or dict")
        for name, weight in weights.items():
            if name in excluded:
                continue
            number_line += weight
            cutoffs.append(number_line)
            options[len(options)] = name
    return options, cutoffs, number_line

View file

@ -6,6 +6,7 @@ generation-options:
max-number-of-routines: 5 # maximum number of routines (main will always be generated)
generate-dead-code: True # generate dead code
max-loop-iterations: 100 # maximum number of iterations in a loop
max-globals: 5 # maximum number of global variables
properties:
max-range-length: 5 # maximum length of ranges, vectors and tuples, (AxA matrices can exist)
use-english-words: True # use english words instead of random names (this may limit the maximum number of names)

View file