diff --git a/src/ast_builder.rs b/src/ast_builder.rs index 5ba676e..464f101 100644 --- a/src/ast_builder.rs +++ b/src/ast_builder.rs @@ -4,6 +4,7 @@ use rand_distr::Distribution; use std::collections::HashMap; use crate::params::Params; +use crate::params::a_lt; use crate::ast::*; #[derive(Default)] @@ -49,17 +50,12 @@ impl AstBuilder { } pub fn generate(&mut self) { - let mut p: f64; - - let p1 = self.params.global_flow.end_generation; - let p2 = p1 + self.params.global_flow.gen_typedef; - loop { - p = self.rng.gen(); + let mut p: f64 = self.rng.gen(); - if p < p1 { + if a_lt(&mut p, self.params.global_flow.end_generation) { break; - } else if p < p2 { + } else if a_lt(&mut p, self.params.global_flow.gen_typedef) { let typedef = self.gen_typedef(); self.ast.push(typedef); } else { @@ -70,13 +66,13 @@ impl AstBuilder { } fn gen_decl(&mut self) -> Stat { - let p: f64 = self.rng.gen(); + let mut p: f64 = self.rng.gen(); let ps = &self.params.statements.gen_decl; let t: BaseType; let assn: Expr; - if p < ps.gen_integer { + if a_lt(&mut p, ps.gen_integer) { t = BaseType::Int; assn = self.gen_integer(); } else { @@ -127,12 +123,12 @@ impl AstBuilder { } fn gen_integer(&mut self) -> Expr { - let p: f64 = self.rng.gen(); + let mut p: f64 = self.rng.gen(); let mut ps = &self.params.types.gen_integer; self.state.int_recursion_depth += 1; - if p < ps.get_instant || self.state.int_recursion_depth >= ps.max_depth { + if a_lt(&mut p, ps.get_instant) || self.state.int_recursion_depth >= ps.max_depth { self.state.int_recursion_depth -= 1; return Expr::new_literal(self.gen_literal(BaseType::Int)); } @@ -143,13 +139,13 @@ impl AstBuilder { ps = &self.params.types.gen_integer; - if p < ps.get_instant + ps.expr_add { + if a_lt(&mut p, ps.expr_add) { Expr::new_binary_op(BinaryOperator::add(lhs, rhs)) - } else if p < ps.get_instant + ps.expr_add + ps.expr_sub { + } else if a_lt(&mut p, ps.expr_sub) { Expr::new_binary_op(BinaryOperator::subtract(lhs, rhs)) - } else if p < ps.get_instant + ps.expr_add + ps.expr_sub + ps.expr_mul { + } else if a_lt(&mut p, ps.expr_mul) { Expr::new_binary_op(BinaryOperator::multiply(lhs, rhs)) - } else if p < ps.get_instant + ps.expr_add + ps.expr_sub + ps.expr_mul + ps.expr_div { + } else if a_lt(&mut p, ps.expr_div) { Expr::new_binary_op(BinaryOperator::divide(lhs, rhs)) } else { Expr::new_literal(self.gen_literal(BaseType::Int)) // TODO: this shouldn't be here @@ -157,15 +153,15 @@ impl AstBuilder { } fn gen_real(&mut self) -> Expr { - let p: f64 = self.rng.gen(); + let mut p: f64 = self.rng.gen(); let mut ps = &self.params.types.gen_real; self.state.real_recursion_depth += 1; - if p < ps.get_instant || self.state.real_recursion_depth >= ps.max_depth { + if a_lt(&mut p, ps.get_instant) || self.state.real_recursion_depth >= ps.max_depth { self.state.real_recursion_depth -= 1; return Expr::new_literal(self.gen_literal(BaseType::Real)); - } else if p < ps.get_instant + ps.gen_integer { + } else if a_lt(&mut p, ps.gen_integer) { self.state.real_recursion_depth -= 1; return self.gen_integer(); } @@ -175,15 +171,14 @@ impl AstBuilder { self.state.real_recursion_depth -= 1; ps = &self.params.types.gen_real; - let prevp = ps.get_instant + ps.gen_integer; - if p < prevp + ps.expr_add { + if a_lt(&mut p, ps.expr_add) { Expr::new_binary_op(BinaryOperator::add(lhs, rhs)) - } else if p < prevp + ps.expr_add + ps.expr_sub { + } else if a_lt(&mut p, ps.expr_sub) { Expr::new_binary_op(BinaryOperator::subtract(lhs, rhs)) - } else if p < prevp + ps.expr_add + ps.expr_sub + ps.expr_mul { + } else if a_lt(&mut p, ps.expr_mul) { Expr::new_binary_op(BinaryOperator::multiply(lhs, rhs)) - } else if p < prevp + ps.expr_add + ps.expr_sub + ps.expr_mul + ps.expr_div { + } else if a_lt(&mut p, ps.expr_div) { Expr::new_binary_op(BinaryOperator::divide(lhs, rhs)) } else { Expr::new_literal(self.gen_literal(BaseType::Real)) // TODO: this shouldn't be here diff --git a/src/params.rs b/src/params.rs index a9fc2da..8c8ad7a 100644 --- a/src/params.rs +++ b/src/params.rs @@ -3,6 +3,13 @@ use std::path::PathBuf; use serde::Deserialize; +/// Additive less than. Updates probability after comparing +pub fn a_lt(p: &mut f64, additive: f64) -> bool { + let is_lt = *p < additive; + *p -= additive; + is_lt +} + // Makes all these structs public and deserializable macro_rules! toml_struct { (struct $name:ident {$($field:ident: $t:ty,)*}) => {