Refactored ASTGenerator #3

Merged
aCompetentBean merged 8 commits from ayrton/refactor into main 2023-11-24 07:34:47 -07:00
2 changed files with 73 additions and 55 deletions
Showing only changes of commit 3dca7092fa - Show all commits

View file

@ -1,10 +1,9 @@
import random
import string
import xml.etree.ElementTree as ET
from english_words import get_english_words_set
from ast_generator.utils import *
from ast_generator.utils import filter_options, _choose_option
from constants import *
import keyword
@ -236,36 +235,11 @@ class AstGenerator:
6: self.generate_in_stream,
}
if include is not None and exclude is not None:
raise ValueError("Cannot specify both include and exclude")
elif include is not None and include in opts:
for i in range(len(opts)):
if opts[i] in include:
continue
else:
options.pop(opts.index(opts[i]))
elif exclude is not None and exclude in opts:
options.pop(opts.index(exclude))
elif include is None and exclude is None:
pass
else:
raise ValueError("Invalid include/exclude options " + str(include) + " " + str(exclude))
# Filter unwanted options
filter_options(exclude, include, options, opts)
while True:
if random.random() < self.settings['block-termination-probability']:
break
a = random.randint(0, number_line)
i = 0
for i in range(len(cutoffs) - 1):
if cutoffs[i] < a < cutoffs[i + 1]:
try:
options[i]()
except KeyError:
continue
except ValueError:
break
break
# Generate the statements
self._generate_from_options(cutoffs, number_line, options)
def generate_int_expr(self):
self._generate_expression([GAZ_INT_KEY],
@ -315,29 +289,9 @@ class AstGenerator:
self.current_nesting_depth -= 1
return
op = ""
a = random.randint(0, number_line - 1)
i = 0
for i in range(len(cutoffs) - 1):
if i == 0:
if a < cutoffs[i]:
op = options[i]
break
if cutoffs[i] <= a < cutoffs[i + 1]:
op = options[i]
break
op = _choose_option(cutoffs, number_line, options)
if op in unary:
self.generate_unary(op, random.choice(expr_type))
elif op == GAZ_BRACKET_TAG:
self.generate_bracket(random.choice(expr_type))
elif comparison:
if op in ['equality', 'inequality']:
self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY, GAZ_CHAR_KEY]))
else:
self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY]))
else:
self.generate_binary(op, random.choice(expr_type))
self._generate_expr(comparison, expr_type, op, unary)
self.current_nesting_depth -= 1
self.current_ast_element = parent
@ -790,3 +744,34 @@ class AstGenerator:
"""
self.pop_scope()
self.current_ast_element = parent
def _generate_from_options(self, cutoffs, number_line, options):
while True:
if random.random() < self.settings['block-termination-probability']:
break
a = random.randint(0, number_line)
i = 0
for i in range(len(cutoffs) - 1):
if cutoffs[i] < a < cutoffs[i + 1]:
try:
options[i]()
except KeyError:
continue
except ValueError:
break
break
def _generate_expr(self, comparison, expr_type, op, unary):
if op in unary:
self.generate_unary(op, random.choice(expr_type))
elif op == GAZ_BRACKET_TAG:
self.generate_bracket(random.choice(expr_type))
elif comparison:
if op in ['equality', 'inequality']:
self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY, GAZ_CHAR_KEY]))
else:
self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY]))
else:
self.generate_binary(op, random.choice(expr_type))

View file

@ -1,3 +1,4 @@
import random
from xml.etree import ElementTree as ET
from constants import GAZ_VAR_TAG, GAZ_ARG_TAG
@ -146,3 +147,35 @@ def get_numberlines(settings_section: str, subsettings: list[str], excluded_valu
raise TypeError("invalid setting type. Found " + str(v) + " instead of expected int")
return options, cutoffs, number_line
def filter_options(exclude, include, options, opts):
if include is not None and exclude is not None:
raise ValueError("Cannot specify both include and exclude")
elif include is not None and include in opts:
for i in range(len(opts)):
if opts[i] in include:
continue
else:
options.pop(opts.index(opts[i]))
elif exclude is not None and exclude in opts:
options.pop(opts.index(exclude))
elif include is None and exclude is None:
pass
else:
raise ValueError("Invalid include/exclude options " + str(include) + " " + str(exclude))
def _choose_option(cutoffs, number_line, options):
op = ""
a = random.randint(0, number_line - 1)
i = 0
for i in range(len(cutoffs) - 1):
if i == 0:
if a < cutoffs[i]:
op = options[i]
break
if cutoffs[i] <= a < cutoffs[i + 1]:
op = options[i]
break
return op