Compare commits

...

5 commits

Author SHA1 Message Date
Ayrton e2be922455 Fixed documentation for unary ops 2023-11-24 06:31:39 -07:00
Ayrton 5ea6eca0ba Fixed tests to account for unarys
- Refactored expressions
2023-11-23 13:40:20 -07:00
Ayrton 3dca7092fa Refactored the generate_expression method 2023-11-23 13:21:43 -07:00
Ayrton eba774dd05 Modified routine generation codestyle 2023-11-23 13:11:46 -07:00
Ayrton e0cd416435 Refactored generate_return 2023-11-23 13:01:13 -07:00
3 changed files with 337 additions and 264 deletions

View file

@ -1,10 +1,9 @@
import random
import string import string
import xml.etree.ElementTree as ET
from english_words import get_english_words_set from english_words import get_english_words_set
from ast_generator.utils import * from ast_generator.utils import *
from ast_generator.utils import filter_options, _choose_option
from constants import * from constants import *
import keyword import keyword
@ -25,6 +24,7 @@ class AstGenerator:
falls into will be selected. falls into will be selected.
""" """
### INITIALIZATION ###
def __init__(self, settings: dict): def __init__(self, settings: dict):
""" """
This class is designed to get the settings from some wrapper class that This class is designed to get the settings from some wrapper class that
@ -83,49 +83,13 @@ class AstGenerator:
self.variable_names = var_name_list[0:var_name_len // 2] self.variable_names = var_name_list[0:var_name_len // 2]
self.routine_names = var_name_list[var_name_len // 2:var_name_len] self.routine_names = var_name_list[var_name_len // 2:var_name_len]
### GENERATION ###
def generate_ast(self): def generate_ast(self):
""" """
@brief generates an AST from a grammar @brief generates an AST from a grammar
""" """
self.generate_top_level_block() self.generate_top_level_block()
def make_element(self, name: str, keys: list[tuple[str, any]]) -> ET.Element:
"""
@brief make an xml element for the ast
@effects modifies self.current_ast_element
@param name: the tag for the element
@param keys: a list of tuple containing keys for the element
"""
element = build_xml_element(keys, name=name)
if self.current_ast_element is not None:
self.current_ast_element.append(element)
self.current_ast_element = element
return element
def make_scoped_element(self, name, keys) -> ET.Element:
"""
@brief make an xml element for the ast with a scope
@param name: the tag for the element
@param keys: a list of tuple containing keys for the element
"""
parent = self.current_ast_element
self.push_scope()
self.make_element(name, keys)
return parent
def exit_scoped_element(self, parent):
"""
@brief leave the current element and return to parent
@param parent: the enclosing element to return to
"""
self.pop_scope()
self.current_ast_element = parent
def generate_top_level_block(self): def generate_top_level_block(self):
""" """
@brief creates the top-level block containing the whole program @brief creates the top-level block containing the whole program
@ -185,136 +149,71 @@ class AstGenerator:
self.pop_scope() self.pop_scope()
self.current_ast_element = parent self.current_ast_element = parent
def generate_loop_condition_check(self, loop_var: Variable):
"""
@brief generates the loop condition check
Ensures that the loop does not iterate more than max-loop-iterations times
@param loop_var:
@return:
"""
# loop var is always an int
assert loop_var.type == GAZ_INT_KEY
# create a conditional xml tag
if_stmt = build_xml_element([], name=GAZ_IF_TAG)
self.current_ast_element.append(if_stmt)
parent = self.current_ast_element
self.current_ast_element = if_stmt
# add the check 'if loop_var >= self.settings['generation_options']['max-loop-iterations']: break'
operation = build_xml_element([("op", ">=")], name=GAZ_OPERATOR_TAG)
self.current_ast_element.append(operation)
self.current_ast_element = operation
lhs = build_xml_element([], name=GAZ_LHS_TAG)
operation.append(lhs)
var = build_xml_element([("name", loop_var.name), ("type", loop_var.type)], name=GAZ_VAR_TAG)
lhs.append(var)
rhs = build_xml_element([], name=GAZ_RHS_TAG)
operation.append(rhs)
rhs.append(
self.make_literal(GAZ_INT_KEY, "'" + str(self.settings['generation-options']['max-loop-iterations']) + "'"))
true_block = build_xml_element([], name=GAZ_BLOCK_TAG)
if_stmt.append(true_block)
self.current_ast_element = true_block
break_stmt = build_xml_element([], name=GAZ_BREAK_TAG)
true_block.append(break_stmt)
# return everything to normalcy
self.current_ast_element = parent
def generate_loop_condition_increment(self, loop_var):
assert loop_var.type == GAZ_INT_KEY
parent = self.current_ast_element
assignment = build_xml_element([], name=GAZ_ASSIGNMENT_TAG)
self.current_ast_element.append(assignment)
self.current_ast_element = assignment
# append the variable
self.current_ast_element.append(loop_var.xml)
# add the increment 'loop_var += 1'
assn_rhs = build_xml_element([], name=GAZ_RHS_TAG)
self.current_ast_element.append(assn_rhs)
self.current_ast_element = assn_rhs
operation = build_xml_element([("op", "+")], name=GAZ_OPERATOR_TAG)
self.current_ast_element.append(operation)
self.current_ast_element = operation
lhs = build_xml_element([], name=GAZ_LHS_TAG)
operation.append(lhs)
var = build_xml_element([("name", loop_var.name), ("type", loop_var.type)], name=GAZ_VAR_TAG)
lhs.append(var)
rhs = build_xml_element([], name=GAZ_RHS_TAG)
operation.append(rhs)
rhs.append(self.make_literal(GAZ_INT_KEY, '1'))
# return everything to normalcy
self.current_ast_element = parent
def generate_return(self, return_type=None, return_value=None): def generate_return(self, return_type=None, return_value=None):
"""
@brief generates a return statement
@param return_type: the type to be returned (if None -> any)
@param return_value: value to be returned (if None -> expr[return_type])
"""
if return_type is None or return_type == GAZ_VOID_TYPE: if return_type is None or return_type == GAZ_VOID_TYPE:
self.current_ast_element.append(build_xml_element([], name=GAZ_RETURN_TAG)) self.current_ast_element.append(self.make_element(GAZ_RETURN_TAG, []))
return
else: else:
# store the parent
parent = self.current_ast_element
# initialize element
keys = [("type", return_type)]
self.make_element(GAZ_RETURN_TAG, keys)
# make either a literal or a random expression based on choice
if return_value is None: if return_value is None:
xml_element = build_xml_element([("type", return_type)], name=GAZ_RETURN_TAG)
self.current_ast_element.append(xml_element)
parent = self.current_ast_element
self.current_ast_element = xml_element
self.generate_expression(return_type) self.generate_expression(return_type)
self.current_ast_element = parent
return
else: else:
xml_element = build_xml_element([("type", return_type)], name=GAZ_RETURN_TAG)
self.current_ast_element.append(xml_element)
parent = self.current_ast_element
self.current_ast_element = xml_element
self.current_ast_element.append(self.make_literal(return_type, return_value)) self.current_ast_element.append(self.make_literal(return_type, return_value))
# return to the parent
self.current_ast_element = parent self.current_ast_element = parent
return
def generate_routine(self, routine_type=None): def generate_routine(self, routine_type=None):
"""
@brief generate a random routine
@param return_type: the type to be returned (if None -> any (including void))
"""
if routine_type is None: if routine_type is None:
routine_type = self.get_routine_type() routine_type = self.get_routine_type() # get a random type
else: else:
routine_type = routine_type pass
# initialize random variables
args = self.generate_routine_args() args = self.generate_routine_args()
name = self.get_name(routine_type) name = self.get_name(routine_type)
return_type = self.get_type(routine_type) return_type = self.get_type(routine_type)
# initialize the routine
routine = Routine(name, routine_type, return_type, args) routine = Routine(name, routine_type, return_type, args)
routine_args = [ routine_args = [
("name", routine.name), ("name", routine.name),
("return_type", routine.return_type), ("return_type", routine.return_type),
] ]
element = build_xml_element(routine_args, name=routine.type) # Generation
self.current_ast_element.append(element)
parent = self.current_ast_element parent = self.current_ast_element
self.current_ast_element = element self.make_scoped_element(routine.type, routine_args)
self.push_scope()
self.define_args(routine.arguments) self.define_args(routine.arguments)
self.generate_block(return_stmt=True, return_type=routine.return_type) self.generate_block(return_stmt=True, return_type=routine.return_type)
self.pop_scope()
self.current_ast_element = parent self.exit_scoped_element(parent)
def define_args(self, args): def define_args(self, args):
"""
@brief Generate the argument tags in a routine
@param args: a list of arguments
"""
for arg in args: for arg in args:
self.current_ast_element.append(arg.xml) self.current_ast_element.append(arg.xml)
self.current_scope.append(arg.name, arg) self.current_scope.append(arg.name, arg)
@ -336,184 +235,145 @@ class AstGenerator:
6: self.generate_in_stream, 6: self.generate_in_stream,
} }
if include is not None and exclude is not None: # Filter unwanted options
raise ValueError("Cannot specify both include and exclude") filter_options(exclude, include, options, opts)
elif include is not None and include in opts:
for i in range(len(opts)):
if opts[i] in include:
continue
else:
options.pop(opts.index(opts[i]))
elif exclude is not None and exclude in opts:
options.pop(opts.index(exclude))
elif include is None and exclude is None:
pass
else:
raise ValueError("Invalid include/exclude options " + str(include) + " " + str(exclude))
while True: # Generate the statements
if random.random() < self.settings['block-termination-probability']: self._generate_from_options(cutoffs, number_line, options)
break
a = random.randint(0, number_line)
i = 0
for i in range(len(cutoffs) - 1):
if cutoffs[i] < a < cutoffs[i + 1]:
try:
options[i]()
except KeyError:
continue
except ValueError:
break
break
def generate_int_expr(self):
self._generate_expression([GAZ_INT_KEY],
self.int_op_numline,
self.int_op_cutoffs,
self.int_op_options,
self.int_unary)
def generate_float_expr(self):
self._generate_expression([GAZ_FLOAT_KEY, GAZ_INT_KEY],
self.float_op_numline,
self.float_op_cutoffs,
self.float_op_options,
self.float_unary)
def generate_bool_expr(self):
self._generate_expression([GAZ_BOOL_KEY],
self.bool_op_numline,
self.bool_op_cutoffs,
self.bool_op_options,
self.bool_unary)
def generate_char_expr(self):
self._generate_expression([GAZ_CHAR_KEY],
self.char_op_numline,
self.char_op_cutoffs,
self.char_op_options)
def generate_comp_expr(self):
self._generate_expression([GAZ_BOOL_KEY],
self.comp_op_numline,
self.comp_op_cutoffs,
self.comp_op_options,
comparison=True)
def _generate_expression(self, expr_type: list[str], number_line, def _generate_expression(self, expr_type: list[str], number_line,
cutoffs, options, unary=None, comparison: bool = False): cutoffs, options, unary=None, comparison: bool = False):
"""
@brief Generate an expression
@param expr_type: a list of types to be used
@param number_line: number line for probability computation
@param cutoffs: cutoffs to be used
@param options: options to be used
@param unary: a list of unary operators in options
"""
if unary is None: if unary is None:
unary = [] unary = []
parent = self.current_ast_element parent = self.current_ast_element
self.current_nesting_depth += 1 self.current_nesting_depth += 1
# Check the expression depth against settings
if self.current_nesting_depth > self.settings['generation-options']['max-nesting-depth'] or random.random() < \ if self.current_nesting_depth > self.settings['generation-options']['max-nesting-depth'] or random.random() < \
self.settings['block-termination-probability']: self.settings['block-termination-probability']:
self.generate_literal(random.choice(expr_type)) self.generate_literal(random.choice(expr_type))
self.current_nesting_depth -= 1 self.current_nesting_depth -= 1
return return
op = "" # Generate
a = random.randint(0, number_line - 1) op = _choose_option(cutoffs, number_line, options)
i = 0 self._generate_expr(comparison, expr_type, op, unary)
for i in range(len(cutoffs) - 1):
if i == 0:
if a < cutoffs[i]:
op = options[i]
break
if cutoffs[i] <= a < cutoffs[i + 1]:
op = options[i]
break
if op in unary:
self.generate_unary(op, random.choice(expr_type))
elif op == GAZ_BRACKET_TAG:
self.generate_bracket(random.choice(expr_type))
elif comparison:
if op in ['equality', 'inequality']:
self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY, GAZ_CHAR_KEY]))
else:
self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY]))
else:
self.generate_binary(op, random.choice(expr_type))
# Return to parent
self.current_nesting_depth -= 1 self.current_nesting_depth -= 1
self.current_ast_element = parent self.current_ast_element = parent
def generate_declaration(self, mut=None): def generate_declaration(self, mut=None): # TODO change this to a bool
"""
@brief Generate a declaration
@param mut: the mutability of the variable ('const' or 'var')
"""
# Initialize the variable
parent = self.current_ast_element parent = self.current_ast_element
decl_type = self.get_type(GAZ_VAR_TAG) decl_type = self.get_type(GAZ_VAR_TAG)
decl_args = [ decl_args = [
("type", decl_type), ("type", decl_type),
] ]
element = build_xml_element(decl_args, name=GAZ_DECLARATION_TAG) self.make_element(GAZ_DECLARATION_TAG, decl_args)
self.current_ast_element.append(element)
self.current_ast_element = element
# Generate the variable
variable = self.generate_variable(decl_type, mut=mut) variable = self.generate_variable(decl_type, mut=mut)
self.current_ast_element.append(variable.xml) self.current_ast_element.append(variable.xml)
self.current_scope.append(variable.name, variable) self.current_scope.append(variable.name, variable) # make sure the variable is in scope
self.generate_xhs(GAZ_RHS_TAG, decl_type) # TODO add real type (decl_type) # Generate the initialization of the variable
self.generate_xhs(GAZ_RHS_TAG, decl_type)
# Return to parent
self.current_ast_element = parent self.current_ast_element = parent
def generate_binary(self, op, op_type): def generate_binary(self, op, op_type):
"""
@brief Generate a binary operation
@param op: the operator
@param op_type: the type of the expression
"""
parent = self.current_ast_element parent = self.current_ast_element
# Check if the operator is valid
if op == "": if op == "":
raise ValueError("op is empty!") raise ValueError("op is empty!")
args = [ args = [
("op", op), ("op", op),
("type", op_type), ("type", op_type),
] ]
element = build_xml_element(args, name=GAZ_OPERATOR_TAG) self.make_element(GAZ_OPERATOR_TAG, args)
self.current_ast_element.append(element)
self.current_ast_element = element
# Gnereate lhs and rhs
self.generate_xhs(GAZ_LHS_TAG, op_type) self.generate_xhs(GAZ_LHS_TAG, op_type)
self.generate_xhs(GAZ_RHS_TAG, op_type) self.generate_xhs(GAZ_RHS_TAG, op_type)
# Return to parent
self.current_ast_element = parent self.current_ast_element = parent
def generate_bracket(self, op_type): def generate_bracket(self, op_type):
parent = self.current_ast_element """
args = [ @brief Generate a bracket operation
("type", op_type),
]
element = build_xml_element(args, name=GAZ_BRACKET_TAG)
self.current_ast_element.append(element)
self.current_ast_element = element
@param op_type: the type of the expression
"""
parent = self.current_ast_element
args = [("type", op_type)]
self.make_element(GAZ_BRACKET_TAG, args)
# Generate the expression in the brackets
self.generate_xhs(GAZ_RHS_TAG, op_type) self.generate_xhs(GAZ_RHS_TAG, op_type)
# Return to parent
self.current_ast_element = parent self.current_ast_element = parent
def generate_xhs(self, handedness, op_type, is_zero=False): def generate_xhs(self, handedness, op_type, is_zero=False):
element = build_xml_element([], name=handedness) """
@brief generate a lhs or a rhs depending on handedness
@param handedness: the handedness
@param op_type: the type of the expression
@param is_zero: if the expression is zero
"""
parent = self.current_ast_element parent = self.current_ast_element
self.current_ast_element.append(element)
self.current_ast_element = element self.make_element(handedness, [])
self.generate_expression(op_type, is_zero=is_zero) self.generate_expression(op_type, is_zero=is_zero)
self.current_ast_element = parent self.current_ast_element = parent
def generate_unary(self, op, op_type=ANY_TYPE): def generate_unary(self, op, op_type=ANY_TYPE):
"""
@brief Generate a unary operation
@param op_type: the type of the expression
"""
parent = self.current_ast_element parent = self.current_ast_element
args = [ args = [
("op", op), ("op", op),
("type", op_type), ("type", op_type),
] ]
element = build_xml_element(args, name=GAZ_UNARY_OPERATOR_TAG) self.make_element(GAZ_UNARY_OPERATOR_TAG, args)
self.current_ast_element.append(element)
self.current_ast_element = element
self.generate_xhs(GAZ_RHS_TAG, op_type) self.generate_xhs(GAZ_RHS_TAG, op_type)
self.current_ast_element = parent self.current_ast_element = parent
def generate_routine_call(self): def generate_routine_call(self): # we should generate a test case with arbitrary number of args
pass pass
def generate_conditional(self): def generate_conditional(self):
@ -687,6 +547,40 @@ class AstGenerator:
def generate_arg(self): def generate_arg(self):
return Argument(self.get_name(GAZ_VAR_TAG), self.get_type(GAZ_VAR_TAG)) return Argument(self.get_name(GAZ_VAR_TAG), self.get_type(GAZ_VAR_TAG))
def generate_int_expr(self):
self._generate_expression([GAZ_INT_KEY],
self.int_op_numline,
self.int_op_cutoffs,
self.int_op_options,
self.int_unary)
def generate_float_expr(self):
self._generate_expression([GAZ_FLOAT_KEY, GAZ_INT_KEY],
self.float_op_numline,
self.float_op_cutoffs,
self.float_op_options,
self.float_unary)
def generate_bool_expr(self):
self._generate_expression([GAZ_BOOL_KEY],
self.bool_op_numline,
self.bool_op_cutoffs,
self.bool_op_options,
self.bool_unary)
def generate_char_expr(self):
self._generate_expression([GAZ_CHAR_KEY],
self.char_op_numline,
self.char_op_cutoffs,
self.char_op_options)
def generate_comp_expr(self):
self._generate_expression([GAZ_BOOL_KEY],
self.comp_op_numline,
self.comp_op_cutoffs,
self.comp_op_options,
comparison=True)
def push_scope(self, xml_element: ET.Element = None): def push_scope(self, xml_element: ET.Element = None):
scope = Scope(self.current_scope) scope = Scope(self.current_scope)
self.symbol_table.append(scope) self.symbol_table.append(scope)
@ -783,3 +677,141 @@ class AstGenerator:
for i in range(len(cutoffs)): for i in range(len(cutoffs)):
if res < cutoffs[i]: if res < cutoffs[i]:
return types[i] return types[i]
### LOOP HELPERS ###
def generate_loop_condition_check(self, loop_var: Variable):
"""
@brief generates the loop condition check
Ensures that the loop does not iterate more than max-loop-iterations times
@param loop_var:
@return:
"""
# loop var is always an int
assert loop_var.type == GAZ_INT_KEY
# create a conditional xml tag
if_stmt = build_xml_element([], name=GAZ_IF_TAG)
self.current_ast_element.append(if_stmt)
parent = self.current_ast_element
self.current_ast_element = if_stmt
# add the check 'if loop_var >= self.settings['generation_options']['max-loop-iterations']: break'
operation = build_xml_element([("op", ">=")], name=GAZ_OPERATOR_TAG)
rhs = self._loop_heloper(loop_var, operation)
rhs.append(
self.make_literal(GAZ_INT_KEY, "'" + str(self.settings['generation-options']['max-loop-iterations']) + "'"))
true_block = build_xml_element([], name=GAZ_BLOCK_TAG)
if_stmt.append(true_block)
self.current_ast_element = true_block
break_stmt = build_xml_element([], name=GAZ_BREAK_TAG)
true_block.append(break_stmt)
# return everything to normalcy
self.current_ast_element = parent
def _loop_heloper(self, loop_var, operation):
self.current_ast_element.append(operation)
self.current_ast_element = operation
lhs = build_xml_element([], name=GAZ_LHS_TAG)
operation.append(lhs)
var = build_xml_element([("name", loop_var.name), ("type", loop_var.type)], name=GAZ_VAR_TAG)
lhs.append(var)
rhs = build_xml_element([], name=GAZ_RHS_TAG)
operation.append(rhs)
return rhs
def generate_loop_condition_increment(self, loop_var):
assert loop_var.type == GAZ_INT_KEY
parent = self.current_ast_element
assignment = build_xml_element([], name=GAZ_ASSIGNMENT_TAG)
self.current_ast_element.append(assignment)
self.current_ast_element = assignment
# append the variable
self.current_ast_element.append(loop_var.xml)
# add the increment 'loop_var += 1'
assn_rhs = build_xml_element([], name=GAZ_RHS_TAG)
self.current_ast_element.append(assn_rhs)
self.current_ast_element = assn_rhs
operation = build_xml_element([("op", "+")], name=GAZ_OPERATOR_TAG)
rhs = self._loop_heloper(loop_var, operation)
rhs.append(self.make_literal(GAZ_INT_KEY, '1'))
# return everything to normalcy
self.current_ast_element = parent
### HELPER FUNCTIONS ###
def make_element(self, name: str, keys: list[tuple[str, any]]) -> ET.Element:
"""
@brief make an xml element for the ast
@effects modifies self.current_ast_element
@param name: the tag for the element
@param keys: a list of tuple containing keys for the element
"""
element = build_xml_element(keys, name=name)
if self.current_ast_element is not None:
self.current_ast_element.append(element)
self.current_ast_element = element
return element
def make_scoped_element(self, name, keys) -> ET.Element:
"""
@brief make an xml element for the ast with a scope
@param name: the tag for the element
@param keys: a list of tuple containing keys for the element
"""
parent = self.current_ast_element
self.push_scope()
self.make_element(name, keys)
return parent
def exit_scoped_element(self, parent):
"""
@brief leave the current element and return to parent
@param parent: the enclosing element to return to
"""
self.pop_scope()
self.current_ast_element = parent
def _generate_from_options(self, cutoffs, number_line, options):
while True:
if random.random() < self.settings['block-termination-probability']:
break
a = random.randint(0, number_line)
i = 0
for i in range(len(cutoffs) - 1):
if cutoffs[i] < a < cutoffs[i + 1]:
try:
options[i]()
except KeyError:
continue
except ValueError:
break
break
def _generate_expr(self, comparison, expr_type, op, unary):
if op in unary:
self.generate_unary(op, random.choice(expr_type))
elif op == GAZ_BRACKET_TAG:
self.generate_bracket(random.choice(expr_type))
elif comparison:
if op in ['equality', 'inequality']:
self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY, GAZ_CHAR_KEY]))
else:
self.generate_binary(op, random.choice([GAZ_INT_KEY, GAZ_FLOAT_KEY]))
else:
self.generate_binary(op, random.choice(expr_type))

View file

@ -60,7 +60,7 @@ class TestGeneration(unittest.TestCase):
def test_generate_assignment(self): def test_generate_assignment(self):
self.ast_gen.ast = ET.Element("block") self.ast_gen.ast = ET.Element("block")
self.ast_gen.current_ast_element = self.ast_gen.ast self.ast_gen.current_ast_element = self.ast_gen.ast
self.ast_gen.generate_declaration() self.ast_gen.generate_declaration(mut='var')
self.ast_gen.generate_assignment() self.ast_gen.generate_assignment()
self.assertIsNotNone(self.ast_gen.ast.find("assignment")) self.assertIsNotNone(self.ast_gen.ast.find("assignment"))
@ -132,17 +132,19 @@ class TestGeneration(unittest.TestCase):
print(ET.tostring(conditional, 'utf-8').decode('utf-8')) print(ET.tostring(conditional, 'utf-8').decode('utf-8'))
self.has_child(conditional)
block = conditional.findall("block")
self.assertEqual(2, len(block))
def has_child(self, conditional):
opts = ['operator', 'unary_operator', 'literal', 'brackets'] opts = ['operator', 'unary_operator', 'literal', 'brackets']
res = [] res = []
for i in opts: for i in opts:
res.append(conditional.find(i)) res.append(conditional.find(i))
res_list = list(filter(lambda x: x is not None, res)) res_list = list(filter(lambda x: x is not None, res))
self.assertGreater(len(res_list), 0) self.assertGreater(len(res_list), 0)
block = conditional.findall("block")
self.assertEqual(2, len(block))
def test_generate_loop(self): def test_generate_loop(self):
self.ast_gen.ast = ET.Element("block") self.ast_gen.ast = ET.Element("block")
self.ast_gen.current_ast_element = self.ast_gen.ast self.ast_gen.current_ast_element = self.ast_gen.ast
@ -153,7 +155,7 @@ class TestGeneration(unittest.TestCase):
# print(ET.tostring(loop, 'utf-8').decode('utf-8')) # print(ET.tostring(loop, 'utf-8').decode('utf-8'))
self.assertIsNotNone(loop.find("operator") or loop.find("unary_operator") or loop.find("literal")) self.has_child(loop)
block = loop.findall("block") block = loop.findall("block")
self.assertEqual(1, len(block)) self.assertEqual(1, len(block))
@ -290,6 +292,12 @@ class TestGeneration(unittest.TestCase):
else: else:
lhs = operator.find("lhs") lhs = operator.find("lhs")
rhs = operator.find("rhs") rhs = operator.find("rhs")
if lhs is None:
if rhs.find("operator") is not None:
res = self.is_no_op(rhs.find("operator"))
elif rhs.find("unary") is not None:
res = self.is_no_op(rhs.find("unary"))
else:
if lhs.find("operator") is not None: if lhs.find("operator") is not None:
res = self.is_no_op(lhs.find("operator")) res = self.is_no_op(lhs.find("operator"))
elif lhs.find("unary") is not None: elif lhs.find("unary") is not None:
@ -297,7 +305,7 @@ class TestGeneration(unittest.TestCase):
elif rhs.find("operator") is not None: elif rhs.find("operator") is not None:
res = self.is_no_op(rhs.find("operator")) res = self.is_no_op(rhs.find("operator"))
elif rhs.find("unary") is not None: elif rhs.find("unary") is not None:
res = self.is_no_op(lhs.find("unary")) res = self.is_no_op(rhs.find("unary"))
return res return res

View file

@ -1,3 +1,4 @@
import random
from xml.etree import ElementTree as ET from xml.etree import ElementTree as ET
from constants import GAZ_VAR_TAG, GAZ_ARG_TAG from constants import GAZ_VAR_TAG, GAZ_ARG_TAG
@ -146,3 +147,35 @@ def get_numberlines(settings_section: str, subsettings: list[str], excluded_valu
raise TypeError("invalid setting type. Found " + str(v) + " instead of expected int") raise TypeError("invalid setting type. Found " + str(v) + " instead of expected int")
return options, cutoffs, number_line return options, cutoffs, number_line
def filter_options(exclude, include, options, opts):
if include is not None and exclude is not None:
raise ValueError("Cannot specify both include and exclude")
elif include is not None and include in opts:
for i in range(len(opts)):
if opts[i] in include:
continue
else:
options.pop(opts.index(opts[i]))
elif exclude is not None and exclude in opts:
options.pop(opts.index(exclude))
elif include is None and exclude is None:
pass
else:
raise ValueError("Invalid include/exclude options " + str(include) + " " + str(exclude))
def _choose_option(cutoffs, number_line, options):
op = ""
a = random.randint(0, number_line - 1)
i = 0
for i in range(len(cutoffs) - 1):
if i == 0:
if a < cutoffs[i]:
op = options[i]
break
if cutoffs[i] <= a < cutoffs[i + 1]:
op = options[i]
break
return op