AstParams: rolling probability sums
All checks were successful
ci/woodpecker/push/build_rust Pipeline was successful

This commit is contained in:
Akemi Izuko 2023-11-18 01:12:52 -07:00
parent bb34acfcfd
commit 1652ca66ff
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC
2 changed files with 26 additions and 24 deletions

View file

@ -4,6 +4,7 @@ use rand_distr::Distribution;
use std::collections::HashMap;
use crate::params::Params;
use crate::params::a_lt;
use crate::ast::*;
#[derive(Default)]
@ -49,17 +50,12 @@ impl AstBuilder {
}
pub fn generate(&mut self) {
let mut p: f64;
let p1 = self.params.global_flow.end_generation;
let p2 = p1 + self.params.global_flow.gen_typedef;
loop {
p = self.rng.gen();
let mut p: f64 = self.rng.gen();
if p < p1 {
if a_lt(&mut p, self.params.global_flow.end_generation) {
break;
} else if p < p2 {
} else if a_lt(&mut p, self.params.global_flow.gen_typedef) {
let typedef = self.gen_typedef();
self.ast.push(typedef);
} else {
@ -70,13 +66,13 @@ impl AstBuilder {
}
fn gen_decl(&mut self) -> Stat {
let p: f64 = self.rng.gen();
let mut p: f64 = self.rng.gen();
let ps = &self.params.statements.gen_decl;
let t: BaseType;
let assn: Expr;
if p < ps.gen_integer {
if a_lt(&mut p, ps.gen_integer) {
t = BaseType::Int;
assn = self.gen_integer();
} else {
@ -127,12 +123,12 @@ impl AstBuilder {
}
fn gen_integer(&mut self) -> Expr {
let p: f64 = self.rng.gen();
let mut p: f64 = self.rng.gen();
let mut ps = &self.params.types.gen_integer;
self.state.int_recursion_depth += 1;
if p < ps.get_instant || self.state.int_recursion_depth >= ps.max_depth {
if a_lt(&mut p, ps.get_instant) || self.state.int_recursion_depth >= ps.max_depth {
self.state.int_recursion_depth -= 1;
return Expr::new_literal(self.gen_literal(BaseType::Int));
}
@ -143,13 +139,13 @@ impl AstBuilder {
ps = &self.params.types.gen_integer;
if p < ps.get_instant + ps.expr_add {
if a_lt(&mut p, ps.expr_add) {
Expr::new_binary_op(BinaryOperator::add(lhs, rhs))
} else if p < ps.get_instant + ps.expr_add + ps.expr_sub {
} else if a_lt(&mut p, ps.expr_sub) {
Expr::new_binary_op(BinaryOperator::subtract(lhs, rhs))
} else if p < ps.get_instant + ps.expr_add + ps.expr_sub + ps.expr_mul {
} else if a_lt(&mut p, ps.expr_mul) {
Expr::new_binary_op(BinaryOperator::multiply(lhs, rhs))
} else if p < ps.get_instant + ps.expr_add + ps.expr_sub + ps.expr_mul + ps.expr_div {
} else if a_lt(&mut p, ps.expr_div) {
Expr::new_binary_op(BinaryOperator::divide(lhs, rhs))
} else {
Expr::new_literal(self.gen_literal(BaseType::Int)) // TODO: this shouldn't be here
@ -157,15 +153,15 @@ impl AstBuilder {
}
fn gen_real(&mut self) -> Expr {
let p: f64 = self.rng.gen();
let mut p: f64 = self.rng.gen();
let mut ps = &self.params.types.gen_real;
self.state.real_recursion_depth += 1;
if p < ps.get_instant || self.state.real_recursion_depth >= ps.max_depth {
if a_lt(&mut p, ps.get_instant) || self.state.real_recursion_depth >= ps.max_depth {
self.state.real_recursion_depth -= 1;
return Expr::new_literal(self.gen_literal(BaseType::Real));
} else if p < ps.get_instant + ps.gen_integer {
} else if a_lt(&mut p, ps.gen_integer) {
self.state.real_recursion_depth -= 1;
return self.gen_integer();
}
@ -175,15 +171,14 @@ impl AstBuilder {
self.state.real_recursion_depth -= 1;
ps = &self.params.types.gen_real;
let prevp = ps.get_instant + ps.gen_integer;
if p < prevp + ps.expr_add {
if a_lt(&mut p, ps.expr_add) {
Expr::new_binary_op(BinaryOperator::add(lhs, rhs))
} else if p < prevp + ps.expr_add + ps.expr_sub {
} else if a_lt(&mut p, ps.expr_sub) {
Expr::new_binary_op(BinaryOperator::subtract(lhs, rhs))
} else if p < prevp + ps.expr_add + ps.expr_sub + ps.expr_mul {
} else if a_lt(&mut p, ps.expr_mul) {
Expr::new_binary_op(BinaryOperator::multiply(lhs, rhs))
} else if p < prevp + ps.expr_add + ps.expr_sub + ps.expr_mul + ps.expr_div {
} else if a_lt(&mut p, ps.expr_div) {
Expr::new_binary_op(BinaryOperator::divide(lhs, rhs))
} else {
Expr::new_literal(self.gen_literal(BaseType::Real)) // TODO: this shouldn't be here

View file

@ -3,6 +3,13 @@ use std::path::PathBuf;
use serde::Deserialize;
/// Additive less than. Updates probability after comparing
pub fn a_lt(p: &mut f64, additive: f64) -> bool {
let is_lt = *p < additive;
*p -= additive;
is_lt
}
// Makes all these structs public and deserializable
macro_rules! toml_struct {
(struct $name:ident {$($field:ident: $t:ty,)*}) => {