Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 56 additions & 103 deletions crates/lean_compiler/src/a_simplify_lang.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use crate::{
Counter, F,
ir::HighLevelOperation,
lang::{
AssignmentTarget, Condition, ConstExpression, ConstMallocLabel, Context, Expression, Function, Line, MathExpr,
Program, Scope, SimpleExpr, Var,
AssignmentTarget, Condition, ConstExpression, ConstMallocLabel, Context, Expression, Function, Line,
MathOperation, Program, Scope, SimpleExpr, Var,
},
parser::ConstArrayValue,
};
use lean_vm::{Boolean, BooleanExpr, CustomHint, FileId, SourceLineNumber, SourceLocation, Table, TableT};
use lean_vm::{
Boolean, BooleanExpr, CustomHint, FileId, FunctionName, SourceLineNumber, SourceLocation, Table, TableT,
};
use std::{
collections::{BTreeMap, BTreeSet},
fmt::{Display, Formatter},
Expand All @@ -16,7 +17,7 @@ use utils::ToUsize;

#[derive(Debug, Clone)]
pub struct SimpleProgram {
pub functions: BTreeMap<String, SimpleFunction>,
pub functions: BTreeMap<FunctionName, SimpleFunction>,
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -79,7 +80,7 @@ pub enum SimpleLine {
},
Assignment {
var: VarOrConstMallocAccess,
operation: HighLevelOperation,
operation: MathOperation, // add / sub / div / mul
arg0: SimpleExpr,
arg1: SimpleExpr,
},
Expand All @@ -95,7 +96,7 @@ pub enum SimpleLine {
line_number: SourceLineNumber,
},
AssertZero {
operation: HighLevelOperation,
operation: MathOperation,
arg0: SimpleExpr,
arg1: SimpleExpr,
},
Expand Down Expand Up @@ -149,7 +150,7 @@ impl SimpleLine {
pub fn equality(arg0: impl Into<VarOrConstMallocAccess>, arg1: impl Into<SimpleExpr>) -> Self {
SimpleLine::Assignment {
var: arg0.into(),
operation: HighLevelOperation::Add,
operation: MathOperation::Add,
arg0: arg1.into(),
arg1: SimpleExpr::zero(),
}
Expand Down Expand Up @@ -485,14 +486,6 @@ fn check_expr_scoping(expr: &Expression, ctx: &Context) {
check_expr_scoping(idx, ctx);
}
}
Expression::Binary {
left,
operation: _,
right,
} => {
check_expr_scoping(left, ctx);
check_expr_scoping(right, ctx);
}
Expression::MathExpr(_, args) => {
for arg in args {
check_expr_scoping(arg, ctx);
Expand Down Expand Up @@ -694,7 +687,6 @@ fn simplify_lines(
}
}
_ => {
// Non-function call - must have exactly one target
assert!(targets.len() == 1, "Non-function call must have exactly one target");
let target = &targets[0];

Expand All @@ -716,29 +708,26 @@ fn simplify_lines(
ArrayAccessType::VarIsAssigned(var.clone()),
);
}
Expression::Binary { left, operation, right } => {
let left = simplify_expr(ctx, state, const_malloc, left, &mut res);
let right = simplify_expr(ctx, state, const_malloc, right, &mut res);
// If both operands are constants, evaluate at compile time and assign the result
if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) =
(&left, &right)
{
let result = ConstExpression::MathExpr(
MathExpr::Binary(*operation),
vec![left_cst.clone(), right_cst.clone()],
)
.try_naive_simplification();
Expression::MathExpr(operation, args) => {
let args_simplified = args
.iter()
.map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res))
.collect::<Vec<_>>();
// If all operands are constants, evaluate at compile time and assign the result
if let Some(const_args) = SimpleExpr::try_vec_as_constant(&args_simplified) {
let result = ConstExpression::MathExpr(*operation, const_args);
res.push(SimpleLine::equality(var.clone(), SimpleExpr::Constant(result)));
} else {
// general case
res.push(SimpleLine::Assignment {
var: var.clone().into(),
operation: *operation,
arg0: left,
arg1: right,
arg0: args_simplified[0].clone(),
arg1: args_simplified[1].clone(),
});
}
}
Expression::MathExpr(_, _) | Expression::Len { .. } => unreachable!(),
Expression::Len { .. } => unreachable!(),
Expression::FunctionCall { .. } => {
unreachable!("FunctionCall should be handled above")
}
Expand Down Expand Up @@ -783,7 +772,7 @@ fn simplify_lines(
let diff_var = state.counters.aux_var();
res.push(SimpleLine::Assignment {
var: diff_var.clone().into(),
operation: HighLevelOperation::Sub,
operation: MathOperation::Sub,
arg0: left,
arg1: right,
});
Expand Down Expand Up @@ -848,7 +837,7 @@ fn simplify_lines(
let diff_var = state.counters.aux_var();
res.push(SimpleLine::Assignment {
var: diff_var.clone().into(),
operation: HighLevelOperation::Sub,
operation: MathOperation::Sub,
arg0: left_simplified,
arg1: right_simplified,
});
Expand Down Expand Up @@ -1142,9 +1131,8 @@ fn simplify_expr(

if let SimpleExpr::Var(array_var) = array
&& let Some(label) = const_malloc.map.get(array_var)
&& let Ok(mut offset) = ConstExpression::try_from(index.clone())
&& let Ok(offset) = ConstExpression::try_from(index.clone())
{
offset = offset.try_naive_simplification();
return SimpleExpr::ConstMallocAccess {
malloc_label: *label,
offset,
Expand All @@ -1168,37 +1156,24 @@ fn simplify_expr(
);
SimpleExpr::Var(aux_arr)
}
Expression::Binary { left, operation, right } => {
let left_var = simplify_expr(ctx, state, const_malloc, left, lines);
let right_var = simplify_expr(ctx, state, const_malloc, right, lines);

if let (SimpleExpr::Constant(left_cst), SimpleExpr::Constant(right_cst)) = (&left_var, &right_var) {
return SimpleExpr::Constant(ConstExpression::MathExpr(
MathExpr::Binary(*operation),
vec![left_cst.clone(), right_cst.clone()],
));
Expression::MathExpr(operation, args) => {
let simplified_args = args
.iter()
.map(|arg| simplify_expr(ctx, state, const_malloc, arg, lines))
.collect::<Vec<_>>();
if let Some(const_args) = SimpleExpr::try_vec_as_constant(&simplified_args) {
return SimpleExpr::Constant(ConstExpression::MathExpr(*operation, const_args));
}

let aux_var = state.counters.aux_var();
assert_eq!(simplified_args.len(), 2);
lines.push(SimpleLine::Assignment {
var: aux_var.clone().into(),
operation: *operation,
arg0: left_var,
arg1: right_var,
arg0: simplified_args[0].clone(),
arg1: simplified_args[1].clone(),
});
SimpleExpr::Var(aux_var)
}
Expression::MathExpr(formula, args) => {
let simplified_args = args
.iter()
.map(|arg| {
simplify_expr(ctx, state, const_malloc, arg, lines)
.as_constant()
.unwrap()
})
.collect::<Vec<_>>();
SimpleExpr::Constant(ConstExpression::MathExpr(*formula, simplified_args))
}
Expression::FunctionCall { function_name, args } => {
let function = ctx
.functions
Expand Down Expand Up @@ -1387,10 +1362,6 @@ fn inline_expr(expr: &mut Expression, args: &BTreeMap<Var, SimpleExpr>, inlining
inline_expr(idx, args, inlining_count);
}
}
Expression::Binary { left, right, .. } => {
inline_expr(left, args, inlining_count);
inline_expr(right, args, inlining_count);
}
Expression::MathExpr(_, math_args) => {
for arg in math_args {
inline_expr(arg, args, inlining_count);
Expand Down Expand Up @@ -1556,10 +1527,6 @@ fn vars_in_expression(expr: &Expression, const_arrays: &BTreeMap<String, ConstAr
vars.extend(vars_in_expression(idx, const_arrays));
}
}
Expression::Binary { left, right, .. } => {
vars.extend(vars_in_expression(left, const_arrays));
vars.extend(vars_in_expression(right, const_arrays));
}
Expression::MathExpr(_, args) => {
for arg in args {
vars.extend(vars_in_expression(arg, const_arrays));
Expand Down Expand Up @@ -1625,18 +1592,26 @@ fn handle_array_assignment(
&& let SimpleExpr::Constant(offset) = simplified_index[0].clone()
&& let SimpleExpr::Var(array_var) = &array
&& let Some(label) = const_malloc.map.get(array_var)
&& let ArrayAccessType::ArrayIsAssigned(Expression::Binary { left, operation, right }) = &access_type
&& let ArrayAccessType::ArrayIsAssigned(Expression::MathExpr(operation, args)) = &access_type
{
let arg0 = simplify_expr(ctx, state, const_malloc, left, res);
let arg1 = simplify_expr(ctx, state, const_malloc, right, res);
let var = VarOrConstMallocAccess::ConstMallocAccess {
malloc_label: *label,
offset,
};
let simplified_args = args
.iter()
.map(|arg| simplify_expr(ctx, state, const_malloc, arg, res))
.collect::<Vec<_>>();
if let Some(const_args) = SimpleExpr::try_vec_as_constant(&simplified_args) {
let result = ConstExpression::MathExpr(*operation, const_args);
res.push(SimpleLine::equality(var.clone(), SimpleExpr::Constant(result.clone())));
}
assert_eq!(simplified_args.len(), 2);
res.push(SimpleLine::Assignment {
var: VarOrConstMallocAccess::ConstMallocAccess {
malloc_label: *label,
offset,
},
var,
operation: *operation,
arg0,
arg1,
arg0: simplified_args[0].clone(),
arg1: simplified_args[1].clone(),
});
return;
}
Expand All @@ -1656,7 +1631,7 @@ fn handle_array_assignment(
let ptr_var = state.counters.aux_var();
res.push(SimpleLine::Assignment {
var: ptr_var.clone().into(),
operation: HighLevelOperation::Add,
operation: MathOperation::Add,
arg0: array,
arg1: simplified_index,
});
Expand Down Expand Up @@ -1684,7 +1659,7 @@ fn create_recursive_function(
let next_iter = format!("@incremented_{iterator}");
body.push(SimpleLine::Assignment {
var: next_iter.clone().into(),
operation: HighLevelOperation::Add,
operation: MathOperation::Add,
arg0: iterator.clone().into(),
arg1: SimpleExpr::one(),
});
Expand All @@ -1706,7 +1681,7 @@ fn create_recursive_function(
let instructions = vec![
SimpleLine::Assignment {
var: diff_var.clone().into(),
operation: HighLevelOperation::Sub,
operation: MathOperation::Sub,
arg0: iterator.into(),
arg1: end,
},
Expand Down Expand Up @@ -1756,10 +1731,6 @@ fn replace_vars_for_unroll_in_expr(
replace_vars_for_unroll_in_expr(index, iterator, unroll_index, iterator_value, internal_vars);
}
}
Expression::Binary { left, right, .. } => {
replace_vars_for_unroll_in_expr(left, iterator, unroll_index, iterator_value, internal_vars);
replace_vars_for_unroll_in_expr(right, iterator, unroll_index, iterator_value, internal_vars);
}
Expression::MathExpr(_, args) => {
for arg in args {
replace_vars_for_unroll_in_expr(arg, iterator, unroll_index, iterator_value, internal_vars);
Expand Down Expand Up @@ -2033,28 +2004,14 @@ fn extract_inlined_calls_from_expr(
lines,
)
}
Expression::Binary { left, operation, right } => {
let (left, left_lines) = extract_inlined_calls_from_expr(left, inlined_functions, inlined_var_counter);
lines.extend(left_lines);
let (right, right_lines) = extract_inlined_calls_from_expr(right, inlined_functions, inlined_var_counter);
lines.extend(right_lines);
(
Expression::Binary {
left: Box::new(left),
operation: *operation,
right: Box::new(right),
},
lines,
)
}
Expression::MathExpr(formula, args) => {
Expression::MathExpr(operation, args) => {
let mut args_new = vec![];
for arg in args {
let (arg, arg_lines) = extract_inlined_calls_from_expr(arg, inlined_functions, inlined_var_counter);
lines.extend(arg_lines);
args_new.push(arg);
}
(Expression::MathExpr(*formula, args_new), lines)
(Expression::MathExpr(*operation, args_new), lines)
}
Expression::FunctionCall { function_name, args } => {
let mut args_new = vec![];
Expand Down Expand Up @@ -2636,10 +2593,6 @@ fn replace_vars_by_const_in_expr(expr: &mut Expression, map: &BTreeMap<Var, F>)
replace_vars_by_const_in_expr(index, map);
}
}
Expression::Binary { left, right, .. } => {
replace_vars_by_const_in_expr(left, map);
replace_vars_by_const_in_expr(right, map);
}
Expression::MathExpr(_, args) => {
for arg in args {
replace_vars_by_const_in_expr(arg, map);
Expand Down
5 changes: 1 addition & 4 deletions crates/lean_compiler/src/b_compile_intermediate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,7 @@ impl Compiler {
VarOrConstMallocAccess::ConstMallocAccess { malloc_label, offset } => {
for scope in self.stack_frame_layout.scopes.iter().rev() {
if let Some(base) = scope.const_mallocs.get(malloc_label) {
return ConstExpression::MathExpr(
MathExpr::Binary(HighLevelOperation::Add),
vec![(*base).into(), offset.clone()],
);
return ConstExpression::MathExpr(MathOperation::Add, vec![(*base).into(), offset.clone()]);
}
}
panic!("Const malloc {malloc_label} not in scope");
Expand Down
Loading