Refactored ASTGenerator #3

Merged
aCompetentBean merged 8 commits from ayrton/refactor into main 2023-11-24 07:34:47 -07:00
2 changed files with 73 additions and 55 deletions
Showing only changes of commit 3dca7092fa - Show all commits

View file

@ -1,10 +1,9 @@
import random
import string import string
import xml.etree.ElementTree as ET
from english_words import get_english_words_set from english_words import get_english_words_set
from ast_generator.utils import * from ast_generator.utils import *
from ast_generator.utils import filter_options, _choose_option
from constants import * from constants import *
import keyword import keyword
@ -236,36 +235,11 @@ class AstGenerator:
6: self.generate_in_stream, 6: self.generate_in_stream,
} }
if include is not None and exclude is not None: # Filter unwanted options
raise ValueError("Cannot specify both include and exclude") filter_options(exclude, include, options, opts)
elif include is not None and include in opts:
for i in range(len(opts)):
if opts[i] in include:
continue
else:
options.pop(opts.index(opts[i]))
elif exclude is not None and exclude in opts:
options.pop(opts.index(exclude))
elif include is None and exclude is None:
pass
else:
raise ValueError("Invalid include/exclude options " + str(include) + " " + str(exclude))
while True: # Generate the statements
if random.random() < self.settings['block-termination-probability']: self._generate_from_options(cutoffs, number_line, options)
break
a = random.randint(0, number_line)
i = 0
for i in range(len(cutoffs) - 1):
if cutoffs[i] < a < cutoffs[i + 1]:
try:
options[i]()
except KeyError:
continue
except ValueError:
break
break
def generate_int_expr(self): def generate_int_expr(self):
self._generate_expression([GAZ_INT_KEY], self._generate_expression([GAZ_INT_KEY],
@ -315,29 +289,9 @@ class AstGenerator:
self.current_nesting_depth -= 1 self.current_nesting_depth -= 1
return return
op = "" op = _choose_option(cutoffs, number_line, options)
a = random.randint(0, number_line - 1)
i = 0
for i in range(len(cutoffs) - 1):
if i == 0:
if a < cutoffs[i]:
op = options[i]
break
if cutoffs[i] <= a < cutoffs[i + 1]:
op = options[i]
break
if op in unary: self._generate_expr(comparison, expr_type, op, unary)
self.generate_unary(op, random.choice(expr_type))
elif op == GAZ_BRACKET_TAG:
self.generate_bracket(random.choice(expr_type))
elif comparison:
if op in ['equality', 'inequality']:
self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY, GAZ_CHAR_KEY]))
else:
self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY]))
else:
self.generate_binary(op, random.choice(expr_type))
self.current_nesting_depth -= 1 self.current_nesting_depth -= 1
self.current_ast_element = parent self.current_ast_element = parent
@ -789,4 +743,35 @@ class AstGenerator:
@param parent: the enclosing element to return to @param parent: the enclosing element to return to
""" """
self.pop_scope() self.pop_scope()
self.current_ast_element = parent self.current_ast_element = parent
def _generate_from_options(self, cutoffs, number_line, options):
while True:
if random.random() < self.settings['block-termination-probability']:
break
a = random.randint(0, number_line)
i = 0
for i in range(len(cutoffs) - 1):
if cutoffs[i] < a < cutoffs[i + 1]:
try:
options[i]()
except KeyError:
continue
except ValueError:
break
break
def _generate_expr(self, comparison, expr_type, op, unary):
if op in unary:
self.generate_unary(op, random.choice(expr_type))
elif op == GAZ_BRACKET_TAG:
self.generate_bracket(random.choice(expr_type))
elif comparison:
if op in ['equality', 'inequality']:
self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY, GAZ_CHAR_KEY]))
else:
self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY]))
else:
self.generate_binary(op, random.choice(expr_type))

View file

@ -1,3 +1,4 @@
import random
from xml.etree import ElementTree as ET from xml.etree import ElementTree as ET
from constants import GAZ_VAR_TAG, GAZ_ARG_TAG from constants import GAZ_VAR_TAG, GAZ_ARG_TAG
@ -145,4 +146,36 @@ def get_numberlines(settings_section: str, subsettings: list[str], excluded_valu
else: else:
raise TypeError("invalid setting type. Found " + str(v) + " instead of expected int") raise TypeError("invalid setting type. Found " + str(v) + " instead of expected int")
return options, cutoffs, number_line return options, cutoffs, number_line
def filter_options(exclude, include, options, opts):
if include is not None and exclude is not None:
raise ValueError("Cannot specify both include and exclude")
elif include is not None and include in opts:
for i in range(len(opts)):
if opts[i] in include:
continue
else:
options.pop(opts.index(opts[i]))
elif exclude is not None and exclude in opts:
options.pop(opts.index(exclude))
elif include is None and exclude is None:
pass
else:
raise ValueError("Invalid include/exclude options " + str(include) + " " + str(exclude))
def _choose_option(cutoffs, number_line, options):
op = ""
a = random.randint(0, number_line - 1)
i = 0
for i in range(len(cutoffs) - 1):
if i == 0:
if a < cutoffs[i]:
op = options[i]
break
if cutoffs[i] <= a < cutoffs[i + 1]:
op = options[i]
break
return op