diff --git a/ast_generator/ast_generator.py b/ast_generator/ast_generator.py index cf650ed..fbd09c1 100644 --- a/ast_generator/ast_generator.py +++ b/ast_generator/ast_generator.py @@ -1,10 +1,9 @@ -import random import string -import xml.etree.ElementTree as ET from english_words import get_english_words_set from ast_generator.utils import * +from ast_generator.utils import filter_options, _choose_option from constants import * import keyword @@ -236,36 +235,11 @@ class AstGenerator: 6: self.generate_in_stream, } - if include is not None and exclude is not None: - raise ValueError("Cannot specify both include and exclude") - elif include is not None and include in opts: - for i in range(len(opts)): - if opts[i] in include: - continue - else: - options.pop(opts.index(opts[i])) - elif exclude is not None and exclude in opts: - options.pop(opts.index(exclude)) - elif include is None and exclude is None: - pass - else: - raise ValueError("Invalid include/exclude options " + str(include) + " " + str(exclude)) + # Filter unwanted options + filter_options(exclude, include, options, opts) - while True: - if random.random() < self.settings['block-termination-probability']: - break - - a = random.randint(0, number_line) - i = 0 - for i in range(len(cutoffs) - 1): - if cutoffs[i] < a < cutoffs[i + 1]: - try: - options[i]() - except KeyError: - continue - except ValueError: - break - break + # Generate the statements + self._generate_from_options(cutoffs, number_line, options) def generate_int_expr(self): self._generate_expression([GAZ_INT_KEY], @@ -315,29 +289,9 @@ class AstGenerator: self.current_nesting_depth -= 1 return - op = "" - a = random.randint(0, number_line - 1) - i = 0 - for i in range(len(cutoffs) - 1): - if i == 0: - if a < cutoffs[i]: - op = options[i] - break - if cutoffs[i] <= a < cutoffs[i + 1]: - op = options[i] - break + op = _choose_option(cutoffs, number_line, options) - if op in unary: - self.generate_unary(op, random.choice(expr_type)) - elif op == GAZ_BRACKET_TAG: - self.generate_bracket(random.choice(expr_type)) - elif comparison: - if op in ['equality', 'inequality']: - self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY, GAZ_CHAR_KEY])) - else: - self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY])) - else: - self.generate_binary(op, random.choice(expr_type)) + self._generate_expr(comparison, expr_type, op, unary) self.current_nesting_depth -= 1 self.current_ast_element = parent @@ -789,4 +743,35 @@ class AstGenerator: @param parent: the enclosing element to return to """ self.pop_scope() - self.current_ast_element = parent \ No newline at end of file + self.current_ast_element = parent + + def _generate_from_options(self, cutoffs, number_line, options): + while True: + if random.random() < self.settings['block-termination-probability']: + break + + a = random.randint(0, number_line) + i = 0 + for i in range(len(cutoffs) - 1): + if cutoffs[i] < a < cutoffs[i + 1]: + try: + options[i]() + except KeyError: + continue + except ValueError: + break + break + + def _generate_expr(self, comparison, expr_type, op, unary): + if op in unary: + self.generate_unary(op, random.choice(expr_type)) + elif op == GAZ_BRACKET_TAG: + self.generate_bracket(random.choice(expr_type)) + elif comparison: + if op in ['equality', 'inequality']: + self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY, GAZ_CHAR_KEY])) + else: + self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY])) + else: + self.generate_binary(op, random.choice(expr_type)) + diff --git a/ast_generator/utils.py b/ast_generator/utils.py index b739816..a4d723f 100644 --- a/ast_generator/utils.py +++ b/ast_generator/utils.py @@ -1,3 +1,4 @@ +import random from xml.etree import ElementTree as ET from constants import GAZ_VAR_TAG, GAZ_ARG_TAG @@ -145,4 +146,36 @@ def get_numberlines(settings_section: str, subsettings: list[str], excluded_valu else: raise TypeError("invalid setting type. Found " + str(v) + " instead of expected int") - return options, cutoffs, number_line \ No newline at end of file + return options, cutoffs, number_line + + +def filter_options(exclude, include, options, opts): + if include is not None and exclude is not None: + raise ValueError("Cannot specify both include and exclude") + elif include is not None and include in opts: + for i in range(len(opts)): + if opts[i] in include: + continue + else: + options.pop(opts.index(opts[i])) + elif exclude is not None and exclude in opts: + options.pop(opts.index(exclude)) + elif include is None and exclude is None: + pass + else: + raise ValueError("Invalid include/exclude options " + str(include) + " " + str(exclude)) + + +def _choose_option(cutoffs, number_line, options): + op = "" + a = random.randint(0, number_line - 1) + i = 0 + for i in range(len(cutoffs) - 1): + if i == 0: + if a < cutoffs[i]: + op = options[i] + break + if cutoffs[i] <= a < cutoffs[i + 1]: + op = options[i] + break + return op