From 168476bae777dc9f6561a8d720cf02b67c00ecc1 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 31 Dec 2025 01:16:14 +0100 Subject: [PATCH] simpler math exprs --- crates/lean_compiler/src/a_simplify_lang.rs | 159 ++++++------------ .../src/b_compile_intermediate.rs | 5 +- crates/lean_compiler/src/ir/instruction.rs | 21 ++- crates/lean_compiler/src/ir/mod.rs | 2 - crates/lean_compiler/src/ir/operation.rs | 59 ------- crates/lean_compiler/src/lang.rs | 120 +++++++------ .../src/parser/parsers/expression.rs | 82 +++------ 7 files changed, 161 insertions(+), 287 deletions(-) delete mode 100644 crates/lean_compiler/src/ir/operation.rs diff --git a/crates/lean_compiler/src/a_simplify_lang.rs b/crates/lean_compiler/src/a_simplify_lang.rs index cbaf087d..2f30428a 100644 --- a/crates/lean_compiler/src/a_simplify_lang.rs +++ b/crates/lean_compiler/src/a_simplify_lang.rs @@ -1,13 +1,14 @@ use crate::{ Counter, F, - ir::HighLevelOperation, lang::{ - AssignmentTarget, Condition, ConstExpression, ConstMallocLabel, Context, Expression, Function, Line, MathExpr, - Program, Scope, SimpleExpr, Var, + AssignmentTarget, Condition, ConstExpression, ConstMallocLabel, Context, Expression, Function, Line, + MathOperation, Program, Scope, SimpleExpr, Var, }, parser::ConstArrayValue, }; -use lean_vm::{Boolean, BooleanExpr, CustomHint, FileId, SourceLineNumber, SourceLocation, Table, TableT}; +use lean_vm::{ + Boolean, BooleanExpr, CustomHint, FileId, FunctionName, SourceLineNumber, SourceLocation, Table, TableT, +}; use std::{ collections::{BTreeMap, BTreeSet}, fmt::{Display, Formatter}, @@ -16,7 +17,7 @@ use utils::ToUsize; #[derive(Debug, Clone)] pub struct SimpleProgram { - pub functions: BTreeMap, + pub functions: BTreeMap, } #[derive(Debug, Clone)] @@ -79,7 +80,7 @@ pub enum SimpleLine { }, Assignment { var: VarOrConstMallocAccess, - operation: HighLevelOperation, + operation: MathOperation, // add / sub / div / mul arg0: SimpleExpr, arg1: SimpleExpr, }, @@ -95,7 +96,7 @@ pub enum SimpleLine { line_number: SourceLineNumber, }, AssertZero { - operation: HighLevelOperation, + operation: MathOperation, arg0: SimpleExpr, arg1: SimpleExpr, }, @@ -149,7 +150,7 @@ impl SimpleLine { pub fn equality(arg0: impl Into, arg1: impl Into) -> Self { SimpleLine::Assignment { var: arg0.into(), - operation: HighLevelOperation::Add, + operation: MathOperation::Add, arg0: arg1.into(), arg1: SimpleExpr::zero(), } @@ -485,14 +486,6 @@ fn check_expr_scoping(expr: &Expression, ctx: &Context) { check_expr_scoping(idx, ctx); } } - Expression::Binary { - left, - operation: _, - right, - } => { - check_expr_scoping(left, ctx); - check_expr_scoping(right, ctx); - } Expression::MathExpr(_, args) => { for arg in args { check_expr_scoping(arg, ctx); @@ -694,7 +687,6 @@ fn simplify_lines( } } _ => { - // Non-function call - must have exactly one target assert!(targets.len() == 1, "Non-function call must have exactly one target"); let target = &targets[0]; @@ -716,29 +708,26 @@ fn simplify_lines( ArrayAccessType::VarIsAssigned(var.clone()), ); } - Expression::Binary { left, operation, right } => { - let left = simplify_expr(ctx, state, const_malloc, left, &mut res); - let right = simplify_expr(ctx, state, const_malloc, right, &mut res); - // If both operands are constants, evaluate at compile time and assign the result - if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) = - (&left, &right) - { - let result = ConstExpression::MathExpr( - MathExpr::Binary(*operation), - vec![left_cst.clone(), right_cst.clone()], - ) - .try_naive_simplification(); + Expression::MathExpr(operation, args) => { + let args_simplified = args + .iter() + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res)) + .collect::>(); + // If all operands are constants, evaluate at compile time and assign the result + if let Some(const_args) = SimpleExpr::try_vec_as_constant(&args_simplified) { + let result = ConstExpression::MathExpr(*operation, const_args); res.push(SimpleLine::equality(var.clone(), SimpleExpr::Constant(result))); } else { + // general case res.push(SimpleLine::Assignment { var: var.clone().into(), operation: *operation, - arg0: left, - arg1: right, + arg0: args_simplified[0].clone(), + arg1: args_simplified[1].clone(), }); } } - Expression::MathExpr(_, _) | Expression::Len { .. } => unreachable!(), + Expression::Len { .. } => unreachable!(), Expression::FunctionCall { .. } => { unreachable!("FunctionCall should be handled above") } @@ -783,7 +772,7 @@ fn simplify_lines( let diff_var = state.counters.aux_var(); res.push(SimpleLine::Assignment { var: diff_var.clone().into(), - operation: HighLevelOperation::Sub, + operation: MathOperation::Sub, arg0: left, arg1: right, }); @@ -848,7 +837,7 @@ fn simplify_lines( let diff_var = state.counters.aux_var(); res.push(SimpleLine::Assignment { var: diff_var.clone().into(), - operation: HighLevelOperation::Sub, + operation: MathOperation::Sub, arg0: left_simplified, arg1: right_simplified, }); @@ -1142,9 +1131,8 @@ fn simplify_expr( if let SimpleExpr::Var(array_var) = array && let Some(label) = const_malloc.map.get(array_var) - && let Ok(mut offset) = ConstExpression::try_from(index.clone()) + && let Ok(offset) = ConstExpression::try_from(index.clone()) { - offset = offset.try_naive_simplification(); return SimpleExpr::ConstMallocAccess { malloc_label: *label, offset, @@ -1168,37 +1156,24 @@ fn simplify_expr( ); SimpleExpr::Var(aux_arr) } - Expression::Binary { left, operation, right } => { - let left_var = simplify_expr(ctx, state, const_malloc, left, lines); - let right_var = simplify_expr(ctx, state, const_malloc, right, lines); - - if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) = (&left_var, &right_var) { - return SimpleExpr::Constant(ConstExpression::MathExpr( - MathExpr::Binary(*operation), - vec![left_cst.clone(), right_cst.clone()], - )); + Expression::MathExpr(operation, args) => { + let simplified_args = args + .iter() + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, lines)) + .collect::>(); + if let Some(const_args) = SimpleExpr::try_vec_as_constant(&simplified_args) { + return SimpleExpr::Constant(ConstExpression::MathExpr(*operation, const_args)); } - let aux_var = state.counters.aux_var(); + assert_eq!(simplified_args.len(), 2); lines.push(SimpleLine::Assignment { var: aux_var.clone().into(), operation: *operation, - arg0: left_var, - arg1: right_var, + arg0: simplified_args[0].clone(), + arg1: simplified_args[1].clone(), }); SimpleExpr::Var(aux_var) } - Expression::MathExpr(formula, args) => { - let simplified_args = args - .iter() - .map(|arg| { - simplify_expr(ctx, state, const_malloc, arg, lines) - .as_constant() - .unwrap() - }) - .collect::>(); - SimpleExpr::Constant(ConstExpression::MathExpr(*formula, simplified_args)) - } Expression::FunctionCall { function_name, args } => { let function = ctx .functions @@ -1387,10 +1362,6 @@ fn inline_expr(expr: &mut Expression, args: &BTreeMap, inlining inline_expr(idx, args, inlining_count); } } - Expression::Binary { left, right, .. } => { - inline_expr(left, args, inlining_count); - inline_expr(right, args, inlining_count); - } Expression::MathExpr(_, math_args) => { for arg in math_args { inline_expr(arg, args, inlining_count); @@ -1556,10 +1527,6 @@ fn vars_in_expression(expr: &Expression, const_arrays: &BTreeMap { - vars.extend(vars_in_expression(left, const_arrays)); - vars.extend(vars_in_expression(right, const_arrays)); - } Expression::MathExpr(_, args) => { for arg in args { vars.extend(vars_in_expression(arg, const_arrays)); @@ -1625,18 +1592,26 @@ fn handle_array_assignment( && let SimpleExpr::Constant(offset) = simplified_index[0].clone() && let SimpleExpr::Var(array_var) = &array && let Some(label) = const_malloc.map.get(array_var) - && let ArrayAccessType::ArrayIsAssigned(Expression::Binary { left, operation, right }) = &access_type + && let ArrayAccessType::ArrayIsAssigned(Expression::MathExpr(operation, args)) = &access_type { - let arg0 = simplify_expr(ctx, state, const_malloc, left, res); - let arg1 = simplify_expr(ctx, state, const_malloc, right, res); + let var = VarOrConstMallocAccess::ConstMallocAccess { + malloc_label: *label, + offset, + }; + let simplified_args = args + .iter() + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, res)) + .collect::>(); + if let Some(const_args) = SimpleExpr::try_vec_as_constant(&simplified_args) { + let result = ConstExpression::MathExpr(*operation, const_args); + res.push(SimpleLine::equality(var.clone(), SimpleExpr::Constant(result.clone()))); + } + assert_eq!(simplified_args.len(), 2); res.push(SimpleLine::Assignment { - var: VarOrConstMallocAccess::ConstMallocAccess { - malloc_label: *label, - offset, - }, + var, operation: *operation, - arg0, - arg1, + arg0: simplified_args[0].clone(), + arg1: simplified_args[1].clone(), }); return; } @@ -1656,7 +1631,7 @@ fn handle_array_assignment( let ptr_var = state.counters.aux_var(); res.push(SimpleLine::Assignment { var: ptr_var.clone().into(), - operation: HighLevelOperation::Add, + operation: MathOperation::Add, arg0: array, arg1: simplified_index, }); @@ -1684,7 +1659,7 @@ fn create_recursive_function( let next_iter = format!("@incremented_{iterator}"); body.push(SimpleLine::Assignment { var: next_iter.clone().into(), - operation: HighLevelOperation::Add, + operation: MathOperation::Add, arg0: iterator.clone().into(), arg1: SimpleExpr::one(), }); @@ -1706,7 +1681,7 @@ fn create_recursive_function( let instructions = vec![ SimpleLine::Assignment { var: diff_var.clone().into(), - operation: HighLevelOperation::Sub, + operation: MathOperation::Sub, arg0: iterator.into(), arg1: end, }, @@ -1756,10 +1731,6 @@ fn replace_vars_for_unroll_in_expr( replace_vars_for_unroll_in_expr(index, iterator, unroll_index, iterator_value, internal_vars); } } - Expression::Binary { left, right, .. } => { - replace_vars_for_unroll_in_expr(left, iterator, unroll_index, iterator_value, internal_vars); - replace_vars_for_unroll_in_expr(right, iterator, unroll_index, iterator_value, internal_vars); - } Expression::MathExpr(_, args) => { for arg in args { replace_vars_for_unroll_in_expr(arg, iterator, unroll_index, iterator_value, internal_vars); @@ -2033,28 +2004,14 @@ fn extract_inlined_calls_from_expr( lines, ) } - Expression::Binary { left, operation, right } => { - let (left, left_lines) = extract_inlined_calls_from_expr(left, inlined_functions, inlined_var_counter); - lines.extend(left_lines); - let (right, right_lines) = extract_inlined_calls_from_expr(right, inlined_functions, inlined_var_counter); - lines.extend(right_lines); - ( - Expression::Binary { - left: Box::new(left), - operation: *operation, - right: Box::new(right), - }, - lines, - ) - } - Expression::MathExpr(formula, args) => { + Expression::MathExpr(operation, args) => { let mut args_new = vec![]; for arg in args { let (arg, arg_lines) = extract_inlined_calls_from_expr(arg, inlined_functions, inlined_var_counter); lines.extend(arg_lines); args_new.push(arg); } - (Expression::MathExpr(*formula, args_new), lines) + (Expression::MathExpr(*operation, args_new), lines) } Expression::FunctionCall { function_name, args } => { let mut args_new = vec![]; @@ -2636,10 +2593,6 @@ fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap) replace_vars_by_const_in_expr(index, map); } } - Expression::Binary { left, right, .. } => { - replace_vars_by_const_in_expr(left, map); - replace_vars_by_const_in_expr(right, map); - } Expression::MathExpr(_, args) => { for arg in args { replace_vars_by_const_in_expr(arg, map); diff --git a/crates/lean_compiler/src/b_compile_intermediate.rs b/crates/lean_compiler/src/b_compile_intermediate.rs index ddc1cc53..3cb3762d 100644 --- a/crates/lean_compiler/src/b_compile_intermediate.rs +++ b/crates/lean_compiler/src/b_compile_intermediate.rs @@ -51,10 +51,7 @@ impl Compiler { VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset } => { for scope in self.stack_frame_layout.scopes.iter().rev() { if let Some(base) = scope.const_mallocs.get(malloc_label) { - return ConstExpression::MathExpr( - MathExpr::Binary(HighLevelOperation::Add), - vec![(*base).into(), offset.clone()], - ); + return ConstExpression::MathExpr(MathOperation::Add, vec![(*base).into(), offset.clone()]); } } panic!("Const malloc {malloc_label} not in scope"); diff --git a/crates/lean_compiler/src/ir/instruction.rs b/crates/lean_compiler/src/ir/instruction.rs index 5fc66446..02fc285d 100644 --- a/crates/lean_compiler/src/ir/instruction.rs +++ b/crates/lean_compiler/src/ir/instruction.rs @@ -1,6 +1,5 @@ -use super::operation::HighLevelOperation; use super::value::IntermediateValue; -use crate::lang::ConstExpression; +use crate::lang::{ConstExpression, MathOperation}; use lean_vm::{BooleanExpr, CustomHint, Operation, SourceLocation, Table, TableT}; use std::fmt::{Display, Formatter}; @@ -64,37 +63,43 @@ pub enum IntermediateInstruction { impl IntermediateInstruction { pub fn computation( - operation: HighLevelOperation, + operation: MathOperation, arg_a: IntermediateValue, arg_c: IntermediateValue, res: IntermediateValue, ) -> Self { match operation { - HighLevelOperation::Add => Self::Computation { + MathOperation::Add => Self::Computation { operation: Operation::Add, arg_a, arg_c, res, }, - HighLevelOperation::Mul => Self::Computation { + MathOperation::Mul => Self::Computation { operation: Operation::Mul, arg_a, arg_c, res, }, - HighLevelOperation::Sub => Self::Computation { + MathOperation::Sub => Self::Computation { operation: Operation::Add, arg_a: res, arg_c, res: arg_a, }, - HighLevelOperation::Div => Self::Computation { + MathOperation::Div => Self::Computation { operation: Operation::Mul, arg_a: res, arg_c, res: arg_a, }, - HighLevelOperation::Exp | HighLevelOperation::Mod => unreachable!(), + MathOperation::Exp + | MathOperation::Mod + | MathOperation::NextMultipleOf + | MathOperation::SaturatingSub + | MathOperation::Log2Ceil => { + unreachable!() + } } } diff --git a/crates/lean_compiler/src/ir/mod.rs b/crates/lean_compiler/src/ir/mod.rs index 97bb17fc..c43fddc7 100644 --- a/crates/lean_compiler/src/ir/mod.rs +++ b/crates/lean_compiler/src/ir/mod.rs @@ -2,10 +2,8 @@ pub mod bytecode; pub mod instruction; -pub mod operation; pub mod value; pub use bytecode::{IntermediateBytecode, MatchBlock}; pub use instruction::IntermediateInstruction; -pub use operation::HighLevelOperation; pub use value::IntermediateValue; diff --git a/crates/lean_compiler/src/ir/operation.rs b/crates/lean_compiler/src/ir/operation.rs deleted file mode 100644 index 59a5107a..00000000 --- a/crates/lean_compiler/src/ir/operation.rs +++ /dev/null @@ -1,59 +0,0 @@ -use crate::F; -use lean_vm::Operation; -use multilinear_toolkit::prelude::*; -use std::fmt::{Display, Formatter}; -use utils::ToUsize; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum HighLevelOperation { - /// Addition operation. - Add, - /// Multiplication operation. - Mul, - /// Subtraction operation (compiled to addition with negation). - Sub, - /// Division operation (compiled to multiplication with inverse). - Div, - /// Exponentiation (only for constant expressions). - Exp, - /// Modulo operation (only for constant expressions). - Mod, -} - -impl HighLevelOperation { - pub fn eval(&self, a: F, b: F) -> F { - match self { - Self::Add => a + b, - Self::Mul => a * b, - Self::Sub => a - b, - Self::Div => a / b, - Self::Exp => a.exp_u64(b.as_canonical_u64()), - Self::Mod => F::from_usize(a.to_usize() % b.to_usize()), - } - } -} - -impl Display for HighLevelOperation { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Self::Add => write!(f, "+"), - Self::Mul => write!(f, "*"), - Self::Sub => write!(f, "-"), - Self::Div => write!(f, "/"), - Self::Exp => write!(f, "**"), - Self::Mod => write!(f, "%"), - } - } -} - -impl TryFrom for Operation { - type Error = String; - - fn try_from(value: HighLevelOperation) -> Result { - match value { - HighLevelOperation::Add => Ok(Self::Add), - HighLevelOperation::Mul => Ok(Self::Mul), - _ => Err(format!("Cannot convert {value:?} to +/x")), - } - } -} diff --git a/crates/lean_compiler/src/lang.rs b/crates/lean_compiler/src/lang.rs index de35994a..3ceda7fc 100644 --- a/crates/lean_compiler/src/lang.rs +++ b/crates/lean_compiler/src/lang.rs @@ -5,7 +5,7 @@ use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Display, Formatter}; use utils::ToUsize; -use crate::{F, ir::HighLevelOperation, parser::ConstArrayValue}; +use crate::{F, parser::ConstArrayValue}; pub use lean_vm::{FileId, FunctionName, SourceLocation}; #[derive(Debug, Clone)] @@ -66,7 +66,7 @@ impl SimpleExpr { pub fn simplify_if_const(&self) -> Self { if let Self::Constant(constant) = self { - return constant.try_naive_simplification().into(); + return constant.clone().into(); } self.clone() } @@ -98,6 +98,18 @@ impl SimpleExpr { Self::ConstMallocAccess { .. } => None, } } + + pub fn try_vec_as_constant(vec: &[Self]) -> Option> { + let mut const_elems = Vec::new(); + for expr in vec { + if let Self::Constant(cst) = expr { + const_elems.push(cst.clone()); + } else { + return None; + } + } + Some(const_elems) + } } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -115,7 +127,7 @@ pub enum ConstantValue { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum ConstExpression { Value(ConstantValue), - MathExpr(MathExpr, Vec), + MathExpr(MathOperation, Vec), } impl From for ConstExpression { @@ -132,11 +144,6 @@ impl TryFrom for ConstExpression { Expression::Value(SimpleExpr::Constant(const_expr)) => Ok(const_expr), Expression::Value(_) => Err(()), Expression::ArrayAccess { .. } => Err(()), - Expression::Binary { left, operation, right } => { - let left_expr = Self::try_from(*left)?; - let right_expr = Self::try_from(*right)?; - Ok(Self::MathExpr(MathExpr::Binary(operation), vec![left_expr, right_expr])) - } Expression::MathExpr(math_expr, args) => { let mut const_args = Vec::new(); for arg in args { @@ -192,14 +199,6 @@ impl ConstExpression { _ => None, }) } - - pub fn try_naive_simplification(&self) -> Self { - if let Some(value) = self.naive_eval() { - Self::scalar(value.to_usize()) - } else { - self.clone() - } - } } impl From for ConstExpression { @@ -230,12 +229,7 @@ pub enum Expression { array: SimpleExpr, index: Vec, // multi-dimensional array access }, - Binary { - left: Box, - operation: HighLevelOperation, - right: Box, - }, - MathExpr(MathExpr, Vec), + MathExpr(MathOperation, Vec), FunctionCall { function_name: String, args: Vec, @@ -248,17 +242,48 @@ pub enum Expression { /// For arbitrary compile-time computations #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub enum MathExpr { - Binary(HighLevelOperation), +pub enum MathOperation { + /// Addition operation. + Add, + /// Multiplication operation. + Mul, + /// Subtraction operation (compiled to addition with negation). + Sub, + /// Division operation (compiled to multiplication with inverse). + Div, + /// Exponentiation (only for constant expressions). + Exp, + /// Modulo operation (only for constant expressions). + Mod, + /// Logarithm ceiling Log2Ceil, + /// similar to rust's next_multiple_of NextMultipleOf, + /// saturating subtraction SaturatingSub, } -impl Display for MathExpr { +impl TryFrom for Operation { + type Error = String; + + fn try_from(value: MathOperation) -> Result { + match value { + MathOperation::Add => Ok(Self::Add), + MathOperation::Mul => Ok(Self::Mul), + _ => Err(format!("Cannot convert {value:?} to add/mul operation")), + } + } +} + +impl Display for MathOperation { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Self::Binary(op) => write!(f, "{op}"), + Self::Add => write!(f, "add"), + Self::Mul => write!(f, "mul"), + Self::Sub => write!(f, "sub"), + Self::Div => write!(f, "div"), + Self::Exp => write!(f, "exp"), + Self::Mod => write!(f, "mod"), Self::Log2Ceil => write!(f, "log2_ceil"), Self::NextMultipleOf => write!(f, "next_multiple_of"), Self::SaturatingSub => write!(f, "saturating_sub"), @@ -266,28 +291,31 @@ impl Display for MathExpr { } } -impl MathExpr { +impl MathOperation { pub fn num_args(&self) -> usize { match self { - Self::Binary(_) => 2, Self::Log2Ceil => 1, - Self::NextMultipleOf => 2, - Self::SaturatingSub => 2, + Self::Add + | Self::Mul + | Self::Sub + | Self::Div + | Self::Exp + | Self::Mod + | Self::NextMultipleOf + | Self::SaturatingSub => 2, } } pub fn eval(&self, args: &[F]) -> F { + assert_eq!(args.len(), self.num_args()); match self { - Self::Binary(op) => { - assert_eq!(args.len(), 2); - op.eval(args[0], args[1]) - } - Self::Log2Ceil => { - assert_eq!(args.len(), 1); - let value = args[0]; - F::from_usize(log2_ceil_usize(value.to_usize())) - } + Self::Add => args[0] + args[1], + Self::Mul => args[0] * args[1], + Self::Sub => args[0] - args[1], + Self::Div => args[0] / args[1], + Self::Exp => args[0].exp_u64(args[1].as_canonical_u64()), + Self::Mod => F::from_usize(args[0].to_usize() % args[1].to_usize()), + Self::Log2Ceil => F::from_usize(log2_ceil_usize(args[0].to_usize())), Self::NextMultipleOf => { - assert_eq!(args.len(), 2); let value = args[0]; let multiple = args[1]; let value_usize = value.to_usize(); @@ -295,10 +323,7 @@ impl MathExpr { let res = value_usize.next_multiple_of(multiple_usize); F::from_usize(res) } - Self::SaturatingSub => { - assert_eq!(args.len(), 2); - F::from_usize(args[0].to_usize().saturating_sub(args[1].to_usize())) - } + Self::SaturatingSub => F::from_usize(args[0].to_usize().saturating_sub(args[1].to_usize())), } } } @@ -356,10 +381,6 @@ impl Expression { .map(|e| e.eval_with(value_fn, array_fn)) .collect::>>()?, ), - Self::Binary { left, operation, right } => Some(operation.eval( - left.eval_with(value_fn, array_fn)?, - right.eval_with(value_fn, array_fn)?, - )), Self::MathExpr(math_expr, args) => { let mut eval_args = Vec::new(); for arg in args { @@ -512,9 +533,6 @@ impl Display for Expression { Self::ArrayAccess { array, index } => { write!(f, "{array}[{index:?}]") } - Self::Binary { left, operation, right } => { - write!(f, "({left} {operation} {right})") - } Self::MathExpr(math_expr, args) => { let args_str = args.iter().map(|arg| format!("{arg}")).collect::>().join(", "); write!(f, "{math_expr}({args_str})") diff --git a/crates/lean_compiler/src/parser/parsers/expression.rs b/crates/lean_compiler/src/parser/parsers/expression.rs index 4a696ae0..bf331a84 100644 --- a/crates/lean_compiler/src/parser/parsers/expression.rs +++ b/crates/lean_compiler/src/parser/parsers/expression.rs @@ -1,8 +1,7 @@ use super::literal::{VarOrConstantParser, evaluate_const_expr}; use super::{ConstArrayValue, Parse, ParseContext, next_inner_pair}; -use crate::lang::MathExpr; +use crate::lang::MathOperation; use crate::{ - ir::HighLevelOperation, lang::{ConstExpression, ConstantValue, Expression, SimpleExpr}, parser::{ error::{ParseResult, SemanticError}, @@ -19,66 +18,42 @@ impl Parse for ExpressionParser { let inner = next_inner_pair(&mut pair.into_inner(), "expression body")?; Self.parse(inner, ctx) } - Rule::add_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Add), - Rule::sub_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Sub), - Rule::mul_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Mul), - Rule::mod_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Mod), - Rule::div_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Div), - Rule::exp_expr => BinaryExpressionParser::parse_with_op(pair, ctx, HighLevelOperation::Exp), - Rule::primary => PrimaryExpressionParser.parse(pair, ctx), + Rule::add_expr => MathOperation::Add.parse(pair, ctx), + Rule::sub_expr => MathOperation::Sub.parse(pair, ctx), + Rule::mul_expr => MathOperation::Mul.parse(pair, ctx), + Rule::mod_expr => MathOperation::Mod.parse(pair, ctx), + Rule::div_expr => MathOperation::Div.parse(pair, ctx), + Rule::exp_expr => MathOperation::Exp.parse(pair, ctx), + Rule::log2_ceil_expr => MathOperation::Log2Ceil.parse(pair, ctx), + Rule::next_multiple_of_expr => MathOperation::NextMultipleOf.parse(pair, ctx), + Rule::saturating_sub_expr => MathOperation::SaturatingSub.parse(pair, ctx), + Rule::var_or_constant => Ok(Expression::Value(VarOrConstantParser.parse(pair, ctx)?)), + Rule::array_access_expr => ArrayAccessParser.parse(pair, ctx), + Rule::len_expr => LenParser.parse(pair, ctx), + Rule::function_call_expr => FunctionCallExprParser.parse(pair, ctx), + Rule::primary => { + let inner = next_inner_pair(&mut pair.into_inner(), "primary expression")?; + Self.parse(inner, ctx) + } other_rule => Err(SemanticError::new(format!("ExpressionParser: Unexpected rule {other_rule:?}")).into()), } } } -pub struct BinaryExpressionParser; - -impl BinaryExpressionParser { - pub fn parse_with_op( - pair: ParsePair<'_>, - ctx: &mut ParseContext, - operation: HighLevelOperation, - ) -> ParseResult { +impl Parse for MathOperation { + fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { let mut inner = pair.into_inner(); - let mut expr = ExpressionParser.parse(next_inner_pair(&mut inner, "binary left")?, ctx)?; + let mut expr = ExpressionParser.parse(next_inner_pair(&mut inner, "math expr left")?, ctx)?; for right in inner { let right_expr = ExpressionParser.parse(right, ctx)?; - expr = Expression::Binary { - left: Box::new(expr), - operation, - right: Box::new(right_expr), - }; + expr = Expression::MathExpr(*self, vec![expr, right_expr]); } Ok(expr) } } -/// Parser for primary expressions (variables, constants, parenthesized expressions). -pub struct PrimaryExpressionParser; - -impl Parse for PrimaryExpressionParser { - fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { - let inner = next_inner_pair(&mut pair.into_inner(), "primary expression")?; - - match inner.as_rule() { - Rule::expression => ExpressionParser.parse(inner, ctx), - Rule::var_or_constant => { - let simple_expr = VarOrConstantParser.parse(inner, ctx)?; - Ok(Expression::Value(simple_expr)) - } - Rule::array_access_expr => ArrayAccessParser.parse(inner, ctx), - Rule::log2_ceil_expr => MathExpr::Log2Ceil.parse(inner, ctx), - Rule::next_multiple_of_expr => MathExpr::NextMultipleOf.parse(inner, ctx), - Rule::saturating_sub_expr => MathExpr::SaturatingSub.parse(inner, ctx), - Rule::len_expr => LenParser.parse(inner, ctx), - Rule::function_call_expr => FunctionCallExprParser.parse(inner, ctx), - _ => Err(SemanticError::new("Invalid primary expression").into()), - } - } -} - pub struct FunctionCallExprParser; impl Parse for FunctionCallExprParser { @@ -118,19 +93,6 @@ impl Parse for ArrayAccessParser { } } -impl Parse for MathExpr { - fn parse(&self, pair: ParsePair<'_>, ctx: &mut ParseContext) -> ParseResult { - let mut inner = pair.into_inner(); - let mut args = Vec::new(); - for i in 0..self.num_args() { - let expr = - ExpressionParser.parse(next_inner_pair(&mut inner, &format!("math expr arg {}", i + 1))?, ctx)?; - args.push(expr); - } - Ok(Expression::MathExpr(*self, args)) - } -} - /// Parser for len() expressions on const arrays (supports indexed access like len(ARR[i])). pub struct LenParser;