AstBuilder: custom float distribution

This commit is contained in:
Akemi Izuko 2023-11-16 21:24:05 -07:00
parent 2c37cae636
commit dfabbaf6f5
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC
2 changed files with 44 additions and 7 deletions

View file

@ -74,7 +74,6 @@ impl ToString for GlobalBlock {
for stat in &self.statements {
s.push_str(&stat.to_string());
s.push(' ');
}
s.push('\n');

View file

@ -22,15 +22,17 @@ pub struct AstBuilder {
ast: GlobalBlock,
name_counter: u64,
rng: rand::rngs::ThreadRng,
rng_float: FloatGenerator,
}
impl AstBuilder {
pub fn from(params: Params) -> Self {
Self {
params,
rng: rand::thread_rng(),
name_counter: 0,
ast: GlobalBlock::default(),
name_counter: 0,
rng: rand::thread_rng(),
rng_float: FloatGenerator::new(&params),
params,
}
}
@ -51,7 +53,7 @@ impl AstBuilder {
}
fn gen_decl(&mut self) -> Box<dyn Statement> {
let t = BaseType::Int;
let t = BaseType::Real;
let v = self.gen_variable_quantified(t, Quantifier::Const);
Box::new(
@ -94,11 +96,14 @@ impl AstBuilder {
fn gen_literal(&mut self, t: BaseType) -> Literal {
match t {
BaseType::Int => {
let r = rand_distr::Beta::new(0.5, 0.5).unwrap();
let r = rand_distr::Beta::new(0.01, 0.01).unwrap();
let i: i32 = (r.sample(&mut self.rng) * (i32::MAX as f64)) as i32;
Literal::Int(i)
}
BaseType::Real => Literal::Real(1.0),
BaseType::Real => {
let f = self.rng_float.sample(&mut self.rng);
Literal::Real(f)
}
BaseType::Never => panic!("Attempted to generate literal of type Never"),
BaseType::Unset => panic!("Attempted to generate literal of type Unset"),
}
@ -116,3 +121,36 @@ impl ToString for AstBuilder {
s
}
}
/// Generates values from regions of interest for a float
///
/// These are concentrated around f32::MIN, 0, and f32::MAX, though any f32 value is possible
struct FloatGenerator {
distro_pos: rand_distr::Normal<f32>,
distro_zero: rand_distr::Normal<f32>,
distro_neg: rand_distr::Normal<f32>,
}
impl FloatGenerator {
fn new(p: &Params) -> Self {
FloatGenerator {
distro_pos: rand_distr::Normal::new(1.0, p.dev.float_gen_distro_pos_stddiv).unwrap(),
distro_zero: rand_distr::Normal::new(0.0, 3.0).unwrap(),
distro_neg: rand_distr::Normal::new(1.0, p.dev.float_gen_distro_neg_stddiv).unwrap(),
}
}
}
impl Distribution<f32> for FloatGenerator {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f32 {
let p: f64 = rng.gen();
if p < 0.33 {
f32::MAX - self.distro_pos.sample(rng).abs()
} else if p < 0.66 {
self.distro_zero.sample(rng)
} else {
-1.0 * (f32::MAX - self.distro_neg.sample(rng).abs())
}
}
}