AstParams: rolling probability sums
All checks were successful
ci/woodpecker/push/build_rust Pipeline was successful

This commit is contained in:
Akemi Izuko 2023-11-18 01:12:52 -07:00
parent bb34acfcfd
commit 1652ca66ff
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC
2 changed files with 26 additions and 24 deletions

View file

@ -4,6 +4,7 @@ use rand_distr::Distribution;
use std::collections::HashMap; use std::collections::HashMap;
use crate::params::Params; use crate::params::Params;
use crate::params::a_lt;
use crate::ast::*; use crate::ast::*;
#[derive(Default)] #[derive(Default)]
@ -49,17 +50,12 @@ impl AstBuilder {
} }
pub fn generate(&mut self) { pub fn generate(&mut self) {
let mut p: f64;
let p1 = self.params.global_flow.end_generation;
let p2 = p1 + self.params.global_flow.gen_typedef;
loop { loop {
p = self.rng.gen(); let mut p: f64 = self.rng.gen();
if p < p1 { if a_lt(&mut p, self.params.global_flow.end_generation) {
break; break;
} else if p < p2 { } else if a_lt(&mut p, self.params.global_flow.gen_typedef) {
let typedef = self.gen_typedef(); let typedef = self.gen_typedef();
self.ast.push(typedef); self.ast.push(typedef);
} else { } else {
@ -70,13 +66,13 @@ impl AstBuilder {
} }
fn gen_decl(&mut self) -> Stat { fn gen_decl(&mut self) -> Stat {
let p: f64 = self.rng.gen(); let mut p: f64 = self.rng.gen();
let ps = &self.params.statements.gen_decl; let ps = &self.params.statements.gen_decl;
let t: BaseType; let t: BaseType;
let assn: Expr; let assn: Expr;
if p < ps.gen_integer { if a_lt(&mut p, ps.gen_integer) {
t = BaseType::Int; t = BaseType::Int;
assn = self.gen_integer(); assn = self.gen_integer();
} else { } else {
@ -127,12 +123,12 @@ impl AstBuilder {
} }
fn gen_integer(&mut self) -> Expr { fn gen_integer(&mut self) -> Expr {
let p: f64 = self.rng.gen(); let mut p: f64 = self.rng.gen();
let mut ps = &self.params.types.gen_integer; let mut ps = &self.params.types.gen_integer;
self.state.int_recursion_depth += 1; self.state.int_recursion_depth += 1;
if p < ps.get_instant || self.state.int_recursion_depth >= ps.max_depth { if a_lt(&mut p, ps.get_instant) || self.state.int_recursion_depth >= ps.max_depth {
self.state.int_recursion_depth -= 1; self.state.int_recursion_depth -= 1;
return Expr::new_literal(self.gen_literal(BaseType::Int)); return Expr::new_literal(self.gen_literal(BaseType::Int));
} }
@ -143,13 +139,13 @@ impl AstBuilder {
ps = &self.params.types.gen_integer; ps = &self.params.types.gen_integer;
if p < ps.get_instant + ps.expr_add { if a_lt(&mut p, ps.expr_add) {
Expr::new_binary_op(BinaryOperator::add(lhs, rhs)) Expr::new_binary_op(BinaryOperator::add(lhs, rhs))
} else if p < ps.get_instant + ps.expr_add + ps.expr_sub { } else if a_lt(&mut p, ps.expr_sub) {
Expr::new_binary_op(BinaryOperator::subtract(lhs, rhs)) Expr::new_binary_op(BinaryOperator::subtract(lhs, rhs))
} else if p < ps.get_instant + ps.expr_add + ps.expr_sub + ps.expr_mul { } else if a_lt(&mut p, ps.expr_mul) {
Expr::new_binary_op(BinaryOperator::multiply(lhs, rhs)) Expr::new_binary_op(BinaryOperator::multiply(lhs, rhs))
} else if p < ps.get_instant + ps.expr_add + ps.expr_sub + ps.expr_mul + ps.expr_div { } else if a_lt(&mut p, ps.expr_div) {
Expr::new_binary_op(BinaryOperator::divide(lhs, rhs)) Expr::new_binary_op(BinaryOperator::divide(lhs, rhs))
} else { } else {
Expr::new_literal(self.gen_literal(BaseType::Int)) // TODO: this shouldn't be here Expr::new_literal(self.gen_literal(BaseType::Int)) // TODO: this shouldn't be here
@ -157,15 +153,15 @@ impl AstBuilder {
} }
fn gen_real(&mut self) -> Expr { fn gen_real(&mut self) -> Expr {
let p: f64 = self.rng.gen(); let mut p: f64 = self.rng.gen();
let mut ps = &self.params.types.gen_real; let mut ps = &self.params.types.gen_real;
self.state.real_recursion_depth += 1; self.state.real_recursion_depth += 1;
if p < ps.get_instant || self.state.real_recursion_depth >= ps.max_depth { if a_lt(&mut p, ps.get_instant) || self.state.real_recursion_depth >= ps.max_depth {
self.state.real_recursion_depth -= 1; self.state.real_recursion_depth -= 1;
return Expr::new_literal(self.gen_literal(BaseType::Real)); return Expr::new_literal(self.gen_literal(BaseType::Real));
} else if p < ps.get_instant + ps.gen_integer { } else if a_lt(&mut p, ps.gen_integer) {
self.state.real_recursion_depth -= 1; self.state.real_recursion_depth -= 1;
return self.gen_integer(); return self.gen_integer();
} }
@ -175,15 +171,14 @@ impl AstBuilder {
self.state.real_recursion_depth -= 1; self.state.real_recursion_depth -= 1;
ps = &self.params.types.gen_real; ps = &self.params.types.gen_real;
let prevp = ps.get_instant + ps.gen_integer;
if p < prevp + ps.expr_add { if a_lt(&mut p, ps.expr_add) {
Expr::new_binary_op(BinaryOperator::add(lhs, rhs)) Expr::new_binary_op(BinaryOperator::add(lhs, rhs))
} else if p < prevp + ps.expr_add + ps.expr_sub { } else if a_lt(&mut p, ps.expr_sub) {
Expr::new_binary_op(BinaryOperator::subtract(lhs, rhs)) Expr::new_binary_op(BinaryOperator::subtract(lhs, rhs))
} else if p < prevp + ps.expr_add + ps.expr_sub + ps.expr_mul { } else if a_lt(&mut p, ps.expr_mul) {
Expr::new_binary_op(BinaryOperator::multiply(lhs, rhs)) Expr::new_binary_op(BinaryOperator::multiply(lhs, rhs))
} else if p < prevp + ps.expr_add + ps.expr_sub + ps.expr_mul + ps.expr_div { } else if a_lt(&mut p, ps.expr_div) {
Expr::new_binary_op(BinaryOperator::divide(lhs, rhs)) Expr::new_binary_op(BinaryOperator::divide(lhs, rhs))
} else { } else {
Expr::new_literal(self.gen_literal(BaseType::Real)) // TODO: this shouldn't be here Expr::new_literal(self.gen_literal(BaseType::Real)) // TODO: this shouldn't be here

View file

@ -3,6 +3,13 @@ use std::path::PathBuf;
use serde::Deserialize; use serde::Deserialize;
/// Additive less than. Updates probability after comparing
pub fn a_lt(p: &mut f64, additive: f64) -> bool {
let is_lt = *p < additive;
*p -= additive;
is_lt
}
// Makes all these structs public and deserializable // Makes all these structs public and deserializable
macro_rules! toml_struct { macro_rules! toml_struct {
(struct $name:ident {$($field:ident: $t:ty,)*}) => { (struct $name:ident {$($field:ident: $t:ty,)*}) => {