From 16a9632f9d5950c031b396d17c3d311661ef2000 Mon Sep 17 00:00:00 2001 From: Akemi Izuko Date: Fri, 17 Nov 2023 14:36:34 -0700 Subject: [PATCH] Ast: traits for basetype --- src/ast/expr.rs | 129 +++++++++++++++++++++++++++++++++++++++--------- src/ast/mod.rs | 29 ++++++++++- 2 files changed, 133 insertions(+), 25 deletions(-) diff --git a/src/ast/expr.rs b/src/ast/expr.rs index 0bf9b79..c8180df 100644 --- a/src/ast/expr.rs +++ b/src/ast/expr.rs @@ -70,7 +70,6 @@ impl GazType for Literal { } } - #[derive(Clone, Default, Builder)] #[builder(setter(into))] pub struct Variable { @@ -102,41 +101,129 @@ impl GazType for Variable { } #[derive(Clone)] -pub enum BinaryOperator { +enum BinaryOp { Add(Expr, Expr), Subtract(Expr, Expr), Multiply(Expr, Expr), Divide(Expr, Expr), } +#[derive(Clone)] +pub struct BinaryOperator { + basetype: BaseType, + op: BinaryOp, +} + +impl BinaryOperator { + pub fn add(left: Expr, right: Expr) -> Self { + let is_both_same = left.get_base() == right.get_base(); + let is_one_real = left.get_base() == BaseType::Real || right.get_base() == BaseType::Real; + let is_one_int = left.get_base() == BaseType::Int || right.get_base() == BaseType::Int; + + let basetype = if is_both_same { + left.get_base() + } else if is_one_real && is_one_int { + BaseType::Real + } else { + panic!("Unsupported types being added: {:?} and {:?}", + left.get_base(), right.get_base()); + }; + + BinaryOperator { + basetype, + op: BinaryOp::Add(left, right), + } + } + + pub fn sub(left: Expr, right: Expr) -> Self { + let is_both_same = left.get_base() == right.get_base(); + let is_one_real = left.get_base() == BaseType::Real || right.get_base() == BaseType::Real; + let is_one_int = left.get_base() == BaseType::Int || right.get_base() == BaseType::Int; + + let basetype = if is_both_same { + left.get_base() + } else if is_one_real && is_one_int { + BaseType::Real + } else { + panic!("Unsupported types being subtracted: {:?} and {:?}", + left.get_base(), right.get_base()); + }; + + BinaryOperator { + basetype, + op: BinaryOp::Subtract(left, right), + } + } + + pub fn multiply(left: Expr, right: Expr) -> Self { + let is_both_same = left.get_base() == right.get_base(); + let is_one_real = left.get_base() == BaseType::Real || right.get_base() == BaseType::Real; + let is_one_int = left.get_base() == BaseType::Int || right.get_base() == BaseType::Int; + + let basetype = if is_both_same { + left.get_base() + } else if is_one_real && is_one_int { + BaseType::Real + } else { + panic!("Unsupported types being multiplied: {:?} and {:?}", + left.get_base(), right.get_base()); + }; + + BinaryOperator { + basetype, + op: BinaryOp::Multiply(left, right), + } + } + + pub fn divide(left: Expr, right: Expr) -> Self { + let is_both_same = left.get_base() == right.get_base(); + let is_one_real = left.get_base() == BaseType::Real || right.get_base() == BaseType::Real; + let is_one_int = left.get_base() == BaseType::Int || right.get_base() == BaseType::Int; + + let basetype = if is_both_same { + left.get_base() + } else if is_one_real && is_one_int { + BaseType::Real + } else { + panic!("Unsupported types being divided: {:?} and {:?}", + left.get_base(), right.get_base()); + }; + + BinaryOperator { + basetype, + op: BinaryOp::Divide(left, right), + } + } +} + impl ToString for BinaryOperator { fn to_string(&self) -> String { let mut s = String::new(); s.push('('); - match self { - BinaryOperator::Add(l, _) => s.push_str(&l.to_string()), - BinaryOperator::Subtract(l, _) => s.push_str(&l.to_string()), - BinaryOperator::Multiply(l, _) => s.push_str(&l.to_string()), - BinaryOperator::Divide(l, _) => s.push_str(&l.to_string()), + match &self.op { + BinaryOp::Add(l, _) => s.push_str(&l.to_string()), + BinaryOp::Subtract(l, _) => s.push_str(&l.to_string()), + BinaryOp::Multiply(l, _) => s.push_str(&l.to_string()), + BinaryOp::Divide(l, _) => s.push_str(&l.to_string()), } s.push(')'); s.push( - match self { - &BinaryOperator::Add(_, _) => '+', - &BinaryOperator::Subtract(_, _) => '-', - &BinaryOperator::Multiply(_, _) => '*', - &BinaryOperator::Divide(_, _) => '/', + match &self.op { + BinaryOp::Add(_, _) => '+', + BinaryOp::Subtract(_, _) => '-', + BinaryOp::Multiply(_, _) => '*', + BinaryOp::Divide(_, _) => '/', } ); s.push('('); - match self { - BinaryOperator::Add(_, r) => s.push_str(&r.to_string()), - BinaryOperator::Subtract(_, r) => s.push_str(&r.to_string()), - BinaryOperator::Multiply(_, r) => s.push_str(&r.to_string()), - BinaryOperator::Divide(_, r) => s.push_str(&r.to_string()), + match &self.op { + BinaryOp::Add(_, r) => s.push_str(&r.to_string()), + BinaryOp::Subtract(_, r) => s.push_str(&r.to_string()), + BinaryOp::Multiply(_, r) => s.push_str(&r.to_string()), + BinaryOp::Divide(_, r) => s.push_str(&r.to_string()), } s.push(')'); @@ -146,12 +233,6 @@ impl ToString for BinaryOperator { impl GazType for BinaryOperator { fn get_base(&self) -> BaseType { - // TODO: This get_base is clearly wrong - match self { - BinaryOperator::Add(l, _) => l.get_base(), - BinaryOperator::Subtract(l, _) => l.get_base(), - BinaryOperator::Multiply(l, _) => l.get_base(), - BinaryOperator::Divide(l, _) => l.get_base(), - } + self.basetype } } diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 80f595d..fbb9e64 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1,6 +1,8 @@ mod expr; mod statement; +use std::fmt; + use expr::{ Expr, Literal, @@ -39,7 +41,7 @@ impl ToString for Quantifier { } } -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Eq)] pub enum BaseType { Int, Real, @@ -64,3 +66,28 @@ impl ToString for BaseType { }.to_string() } } + +impl fmt::Debug for BaseType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let name = match self { + BaseType::Int => "integer", + BaseType::Real => "real", + BaseType::Never => "never", + BaseType::Unset => "unset", + }; + + write!(f, "{}", name) + } +} + +impl PartialEq for BaseType { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (&BaseType::Int, &BaseType::Int) => true, + (&BaseType::Real, &BaseType::Real) => true, + (&BaseType::Never, &BaseType::Never) => true, + (&BaseType::Unset, &BaseType::Unset) => false, + _ => false, + } + } +}