From 7e6cb221a94bbaac118fe76afa410bc785b4fc0c Mon Sep 17 00:00:00 2001 From: bmwangmh Date: Fri, 15 Aug 2025 20:04:29 +0800 Subject: [PATCH 1/7] fix: compiler panic when a closure call itself --- src/riscv_generate.ml | 30 ++++++++++++++++++++++-------- test/src/closure05/closure05.ans | 2 ++ test/src/closure05/closure05.mbt | 26 ++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 8 deletions(-) create mode 100644 test/src/closure05/closure05.ans create mode 100644 test/src/closure05/closure05.mbt diff --git a/src/riscv_generate.ml b/src/riscv_generate.ml index e685114..e63fe22 100644 --- a/src/riscv_generate.ml +++ b/src/riscv_generate.ml @@ -14,6 +14,9 @@ let global_vars = ref Stringset.empty (** The function/closure we're currently dealing with *) let current_function = ref "" +(** The origin current function name. Used for judge whether a closure calls itself *) +let current_base_name = ref "" + (** The environment (i.e. the pointer passed to the closure) of current closure *) let current_env = ref unit @@ -806,14 +809,21 @@ let rec do_convert tac (expr: Mcore.expr) = (if Stringset.mem fn !fn_names then Vec.push tac (Call { rd; fn; args }) else - (* Here `fn` is a closure *) - let closure = { name = fn; ty = Mtype.T_bytes } in - let fptr = new_temp Mtype.T_bytes in - Vec.push tac (Load { rd = fptr; rs = closure; offset = 0; byte = pointer_size }); + if fn = !current_base_name then + let fptr = new_temp Mtype.T_bytes in + Vec.push tac (Load { rd = fptr; rs = !current_env; offset = 0; byte = pointer_size }); + + let args = args @ [!current_env] in + Vec.push tac (CallIndirect { rd; rs = fptr; args }); + else + (* Here `fn` is a closure *) + let closure = { name = fn; ty = Mtype.T_bytes } in + let fptr = new_temp Mtype.T_bytes in + Vec.push tac (Load { rd = fptr; rs = closure; offset = 0; byte = pointer_size }); - (* Closure, along with environment, should be passed as argument *) - let args = args @ [closure] in - Vec.push tac (CallIndirect { rd; rs = fptr; args })); + (* Closure, along with environment, should be passed as argument *) + let args = args @ [closure] in + Vec.push tac (CallIndirect { rd; rs = fptr; args })); (* If this is a `Join`, then we must jump to the corresponding letfn *) if kind = Join then ( @@ -1081,6 +1091,7 @@ let rec do_convert tac (expr: Mcore.expr) = (* This is a different function from the current one, *) (* so we must protect all global variables before generating body *) let this_fn = !current_function in + let this_base_fn = !current_base_name in let this_env = !current_env in let this_join = !current_join in let this_join_ret = !current_join_ret in @@ -1088,6 +1099,7 @@ let rec do_convert tac (expr: Mcore.expr) = (* Set the correct values for this new function *) let fn_name = Printf.sprintf "%s_closure_%s" !current_function name in current_function := fn_name; + current_base_name := name; current_env := fn_env; current_join := ""; current_join_ret := unit; @@ -1098,6 +1110,7 @@ let rec do_convert tac (expr: Mcore.expr) = (* Put them back *) current_function := this_fn; + current_base_name := this_base_fn; current_env := this_env; current_join := this_join; current_join_ret := this_join_ret; @@ -1122,7 +1135,7 @@ let rec do_convert tac (expr: Mcore.expr) = (* Store environment variables *) List.iter2 (fun (arg: var) offset -> let size = sizeof arg.ty in - if arg.name = !current_function then + if arg.name = !current_base_name then (* This closure captures myself, so I need to make myself a closure *) (* Fortunately my environment is just my closure *) Vec.push tac (Store { rd = !current_env; rs = closure; offset = offset - size; byte = size }) @@ -1720,6 +1733,7 @@ let convert_toplevel _start (top: Mcore.top_item) = let args = List.map var_of_param func.params in current_function := fn; + current_base_name := fn; let body = convert_expr func.body in (* Record the index of arguments that are traits *) diff --git a/test/src/closure05/closure05.ans b/test/src/closure05/closure05.ans new file mode 100644 index 0000000..4f7108e --- /dev/null +++ b/test/src/closure05/closure05.ans @@ -0,0 +1,2 @@ +60 +() diff --git a/test/src/closure05/closure05.mbt b/test/src/closure05/closure05.mbt new file mode 100644 index 0000000..76f191b --- /dev/null +++ b/test/src/closure05/closure05.mbt @@ -0,0 +1,26 @@ +fn create(n: Int) -> (Int) -> (Int) -> (Int) -> Unit { + + fn layer1(x: Int) { + + fn layer2(y: Int) { + + fn layer3(z: Int) { + let _ = if z <= 0 { + println(x + y + n) + } else { + layer3(z - 1) + }; + }; + + layer3 + }; + + layer2 + }; + + layer1 +}; + +fn main { + println(create(10)(20)(30)(1)); +}; \ No newline at end of file From ca7c0728b117538c8a0b60e85096a9593eaab59e Mon Sep 17 00:00:00 2001 From: bmwangmh Date: Wed, 20 Aug 2025 21:02:33 +0800 Subject: [PATCH 2/7] feat: Basic purity table and dead pure function call elimination --- src/riscv_opt_peephole.ml | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/src/riscv_opt_peephole.ml b/src/riscv_opt_peephole.ml index ce3d5e9..ccbc04e 100644 --- a/src/riscv_opt_peephole.ml +++ b/src/riscv_opt_peephole.ml @@ -116,6 +116,35 @@ let to_itype fn = in List.iter (fun block -> (block_of block).body <- convert block) blocks +let purity_table = Hashtbl.create 64 + +let rec is_pure fn = + match Hashtbl.find_opt purity_table fn with + | Some x -> x + | None -> + let global_vars = !Riscv_generate.global_vars in + let pure = + let blocks = get_blocks fn in + List.for_all (fun block -> + List.for_all (fun x -> + let is_global = ref false in + reg_iterd (fun var -> + if Stringset.mem var.name global_vars then + is_global := true; + var) x; + not !is_global + ) (body_of block) && + List.for_all (fun x -> match x with + | Call { fn = fn' } -> if fn <> fn' then is_pure fn' else true + | CallExtern _ | CallIndirect _ | Store _ + | Malloc _ | Alloca _ -> false + | _ -> true) (body_of block) + ) blocks + in + Hashtbl.add purity_table fn pure; + (* Printf.printf "Checking purity of %s %s\n" fn (string_of_bool pure); *) + pure + let remove_dead_variable fn = let blocks = get_blocks fn in let liveness = liveness_analysis fn in @@ -140,7 +169,8 @@ let remove_dead_variable fn = (* TODO: refine this, so that calls to pure functions are also eliminated *) match x with - | Call _ | CallExtern _ | CallIndirect _ -> true + | Call { fn } when not (is_pure fn) -> true + | CallExtern _ | CallIndirect _ -> true | _ -> !preserve ) body; |> Basic_vec.of_list From 7e127c4b9b2024ee1315c4e95c354eaabd62e82b Mon Sep 17 00:00:00 2001 From: bmwangmh Date: Mon, 25 Aug 2025 20:31:47 +0800 Subject: [PATCH 3/7] Allow pure functions with local stores --- src/riscv_opt_peephole.ml | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/riscv_opt_peephole.ml b/src/riscv_opt_peephole.ml index ccbc04e..1eac881 100644 --- a/src/riscv_opt_peephole.ml +++ b/src/riscv_opt_peephole.ml @@ -123,8 +123,9 @@ let rec is_pure fn = | Some x -> x | None -> let global_vars = !Riscv_generate.global_vars in + let alloca_data = ref Stringset.empty in + let blocks = get_blocks fn in let pure = - let blocks = get_blocks fn in List.for_all (fun block -> List.for_all (fun x -> let is_global = ref false in @@ -135,14 +136,15 @@ let rec is_pure fn = not !is_global ) (body_of block) && List.for_all (fun x -> match x with - | Call { fn = fn' } -> if fn <> fn' then is_pure fn' else true - | CallExtern _ | CallIndirect _ | Store _ - | Malloc _ | Alloca _ -> false + | Alloca { rd } -> alloca_data := Stringset.add rd.name !alloca_data; true + | Call { fn = fn' } -> if fn <> fn' then is_pure fn' else false + | Store { rs } -> Stringset.mem rs.name !alloca_data + | CallExtern _ | CallIndirect _ + | Malloc _ -> false | _ -> true) (body_of block) ) blocks in Hashtbl.add purity_table fn pure; - (* Printf.printf "Checking purity of %s %s\n" fn (string_of_bool pure); *) pure let remove_dead_variable fn = From db0df655297d71d0a1be2690410e6eddc428dfdc Mon Sep 17 00:00:00 2001 From: bmwangmh Date: Fri, 29 Aug 2025 16:08:52 +0800 Subject: [PATCH 4/7] Bugfix: an omitted assign to rd in Psetbytesitem --- src/riscv_generate.ml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/riscv_generate.ml b/src/riscv_generate.ml index e63fe22..09a5024 100644 --- a/src/riscv_generate.ml +++ b/src/riscv_generate.ml @@ -402,7 +402,8 @@ let deal_with_prim tac rd (prim: Primitive.prim) args = let altered = new_temp Mtype.T_string in Vec.push tac (Add { rd = altered; rs1 = str; rs2 = i }); - Vec.push tac (Store { rd = item; rs = altered; offset = 0; byte = 1 }) + Vec.push tac (Store { rd = item; rs = altered; offset = 0; byte = 1 }); + Vec.push tac (Assign { rd; rs = unit }) (* Be cautious that each `char` is 2 bytes long, which is extremely counter-intuitive. *) | Pgetstringitem -> From 957528a6727ea02f6ce7ee03558ad910e67a143c Mon Sep 17 00:00:00 2001 From: bmwangmh Date: Sat, 30 Aug 2025 17:12:49 +0800 Subject: [PATCH 5/7] Bugfix: AssignFP format and rtypeu for unsigned comparison --- src/riscv_ssa.ml | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/riscv_ssa.ml b/src/riscv_ssa.ml index 9e03905..5622ad0 100644 --- a/src/riscv_ssa.ml +++ b/src/riscv_ssa.ml @@ -283,10 +283,17 @@ let to_string t = (* Deals with signedness: signed or unsigned *) (* In most R-type instructions, `rd` and the 2 `rs`es have the same type *) + + (* For unsigned comparison, `rd` may have different type *) + (* SSA interpreter relies on this function, so dealing with `most` is not enough *) let rtypeu op ({ rd; rs1; rs2 }: r_type) = - let width = (match rd.ty with - | T_uint | T_uint64 -> "u" - | _ -> "") in + let isu_rs1 = (match rs1.ty with + | T_uint | T_uint64 -> true + | _ -> false) in + let isu_rs2 = (match rs2.ty with + | T_uint | T_uint64 -> true + | _ -> false) in + let width = if (isu_rs1 && isu_rs2) then "u" else "" in Printf.sprintf "%s%s %s %s %s" op width rd.name rs1.name rs2.name in @@ -406,7 +413,7 @@ let to_string t = Printf.sprintf "li %s %s" rd.name (Int64.to_string imm) | AssignFP { rd; imm; } -> - Printf.sprintf "fli %s = %f" rd.name imm + Printf.sprintf "fli %s %f" rd.name imm | AssignLabel { rd; imm; } -> Printf.sprintf "la %s %s" rd.name imm From 6f566d8c8f014d31b876307a58df44986217d60c Mon Sep 17 00:00:00 2001 From: bmwangmh Date: Sat, 30 Aug 2025 17:49:01 +0800 Subject: [PATCH 6/7] Bugfix: logical not instead of bitwise not --- test/interpreter.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interpreter.cpp b/test/interpreter.cpp index 43db751..99685fe 100644 --- a/test/interpreter.cpp +++ b/test/interpreter.cpp @@ -216,7 +216,7 @@ int64_t interpret(std::string label) { } if (op == "not") { - VAL(1) = ~VAL(2); + VAL(1) = VAL(2) == 0; continue; } From da1ed536f8ea9bcd1d0b176176da65be51ea06a5 Mon Sep 17 00:00:00 2001 From: bmwangmh Date: Sat, 30 Aug 2025 17:50:52 +0800 Subject: [PATCH 7/7] Support for floating-point number in interpreter and test cases --- src/riscv_virtasm_generate.ml | 37 ++++++++------- test/interpreter.cpp | 63 +++++++++++++++++++++++++- test/src/floatpoint01/floatpoint01.ans | 6 +++ test/src/floatpoint01/floatpoint01.mbt | 13 ++++++ test/src/floatpoint02/floatpoint02.ans | 6 +++ test/src/floatpoint02/floatpoint02.mbt | 37 +++++++++++++++ 6 files changed, 144 insertions(+), 18 deletions(-) create mode 100644 test/src/floatpoint01/floatpoint01.ans create mode 100644 test/src/floatpoint01/floatpoint01.mbt create mode 100644 test/src/floatpoint02/floatpoint02.ans create mode 100644 test/src/floatpoint02/floatpoint02.mbt diff --git a/src/riscv_virtasm_generate.ml b/src/riscv_virtasm_generate.ml index 24ee9e1..aace00e 100644 --- a/src/riscv_virtasm_generate.ml +++ b/src/riscv_virtasm_generate.ml @@ -375,26 +375,29 @@ let convert_single name body terminator (inst: Riscv_ssa.t) = terminator := Term.J (label_of label) | JumpIndirect { rs; possibilities } -> (* TODO: Optimizations on possibilities *) - terminator := Term.Jalr { - rd = Slot.Reg Zero; - rs1 = slot_v rs; - offset = 0; - }; + if List.length possibilities = 1 then + terminator := Term.J (label_of (List.hd possibilities)) + else + terminator := Term.Jalr { + rd = Slot.Reg Zero; + rs1 = slot_v rs; + offset = 0; + }; (* Floating point instructions *) + | FAdd _ -> () + | FSub _ -> () + | FMul _ -> () + | FDiv _ -> () + | FLess _ -> () + | FLeq _ -> () + | FGreat _ -> () + | FGeq _ -> () + | FEq _ -> () + | FNeq _ -> () + | FNeg _ -> () + | AssignFP _ -> () (* - | FAdd -> _ - | FSub -> _ - | FMul -> _ - | FDiv -> _ - | FLess -> _ - | FLeq -> _ - | FGreat -> _ - | FGeq -> _ - | FEq -> _ - | FNeq -> _ - | FNeg -> _ - | AssignFP -> _ | FnDecl -> _ | GlobalVarDecl -> _ | ExtArray _ -> _ *) diff --git a/test/interpreter.cpp b/test/interpreter.cpp index 99685fe..745496c 100644 --- a/test/interpreter.cpp +++ b/test/interpreter.cpp @@ -19,7 +19,6 @@ std::map> blocks; std::map> fns; // Values of registers used when interpreting. -// TODO: currently no FP supported. std::map regs; std::map labels; @@ -64,13 +63,35 @@ int int_of(std::string s) { #define RTYPEU(name, op) std::make_pair(name, [](int64_t x, int64_t y) { return (uint64_t)x op (uint64_t)y; }) #define RTYPEUW(name, op) std::make_pair(name, [](int64_t x, int64_t y) { return (unsigned)x op (unsigned)y; }) #define ITYPEW(name, op) std::make_pair(name, [](int64_t x, int imm) -> int64_t { return (int)x op imm; }) +#define RTYPEF(name, op) std::make_pair(name, [](int64_t x, int64_t y) { \ + double fx, fy; \ + std::memcpy(&fx, &x, 8); \ + std::memcpy(&fy, &y, 8); \ + double fres = fx op fy; \ + int64_t res; \ + std::memcpy(&res, &fres, 8); \ + return res; \ + }) +#define RTYPEFCMP(name, op) std::make_pair(name, [](int64_t x, int64_t y) { \ + double fx, fy; \ + std::memcpy(&fx, &x, 8); \ + std::memcpy(&fy, &y, 8); \ + int64_t res = fx op fy; \ + return res; \ + }) #define VAL(i) regs[args[i]] #ifdef VERBOSE #define OUTPUT(name, value) std::cerr << "\t" << name << " = " << value << "\n\n" +#define OUTPUTF(name, value) \ + double fvalue; \ + std::memcpy(&fvalue, &value, 8); \ + std::cerr << "\t" << name << " = " << fvalue << "\n\n" + #define SAY(str) std::cerr << str << "\n" #else #define OUTPUT(name, value) +#define OUTPUTF(name, value) #define SAY(str) #endif @@ -118,6 +139,19 @@ int64_t interpret(std::string label) { RTYPEUW("moduw", %), }; + static std::map> rtypef = { + RTYPEF("fadd", +), + RTYPEF("fsub", -), + RTYPEF("fmul", *), + RTYPEF("fdiv", /), + RTYPEFCMP("fle", <), + RTYPEFCMP("fleq", <=), + RTYPEFCMP("fge", >), + RTYPEFCMP("fgeq", >=), + RTYPEFCMP("feq", ==), + RTYPEFCMP("fneq", !=), + }; + static std::map> load = { MEM("lb", char), MEM("lh", char16_t), @@ -158,6 +192,12 @@ int64_t interpret(std::string label) { continue; } + if (rtypef.contains(op)) { + VAL(1) = rtypef[op](VAL(2), VAL(3)); + OUTPUTF(args[1], VAL(1)); + continue; + } + if (load.contains(op)) { VAL(1) = load[op](VAL(2), int_of(args[3])); OUTPUT(args[1], VAL(1)); @@ -215,6 +255,15 @@ int64_t interpret(std::string label) { continue; } + if (op == "fneg") { + int64_t val = VAL(2); + double fval; + std::memcpy(&fval, &val, 8); + fval = -fval; + std::memcpy(&VAL(1), &fval, 8); + continue; + } + if (op == "not") { VAL(1) = VAL(2) == 0; continue; @@ -379,6 +428,18 @@ int64_t interpret(std::string label) { continue; } + if (op == "fli"){ + std::stringstream ss(args[2]); + double rs; + ss >> rs; + + int64_t rs_int; + std::memcpy(&rs_int, &rs, 8); + VAL(1) = rs_int; + OUTPUT(args[1], rs); + continue; + } + if (op == "mv") { VAL(1) = VAL(2); OUTPUT(args[1], VAL(1)); diff --git a/test/src/floatpoint01/floatpoint01.ans b/test/src/floatpoint01/floatpoint01.ans new file mode 100644 index 0000000..cdb0386 --- /dev/null +++ b/test/src/floatpoint01/floatpoint01.ans @@ -0,0 +1,6 @@ +3.3 +-1.1 +2.42 +0.5 +[1.1, 2.2, 3.3, 4.4, 5.5] +-3.14 diff --git a/test/src/floatpoint01/floatpoint01.mbt b/test/src/floatpoint01/floatpoint01.mbt new file mode 100644 index 0000000..74e69ca --- /dev/null +++ b/test/src/floatpoint01/floatpoint01.mbt @@ -0,0 +1,13 @@ +fn f(x : Double) -> Double { + -x +} + +fn main { + println(1.1 + 2.2) + println(1.1 - 2.2) + println(1.1 * 2.2) + println(1.1 / 2.2) + let arr1 = [1.1, 2.2, 3.3, 4.4, 5.5] + println(arr1) + println(f(3.14)) +} \ No newline at end of file diff --git a/test/src/floatpoint02/floatpoint02.ans b/test/src/floatpoint02/floatpoint02.ans new file mode 100644 index 0000000..17607c6 --- /dev/null +++ b/test/src/floatpoint02/floatpoint02.ans @@ -0,0 +1,6 @@ +4 +3 +7 +6 +1 +9 diff --git a/test/src/floatpoint02/floatpoint02.mbt b/test/src/floatpoint02/floatpoint02.mbt new file mode 100644 index 0000000..311f4a0 --- /dev/null +++ b/test/src/floatpoint02/floatpoint02.mbt @@ -0,0 +1,37 @@ +fn main { + let mut x : Double = 0 + let mut t_geq : Int = 0 + let mut t_ge : Int = 0 + let mut t_leq : Int = 0 + let mut t_le : Int = 0 + let mut t_eq : Int = 0 + let mut t_neq : Int = 0 + + for i = 0; i < 10; i = i + 1 { + x += 0.1 + if (x <= 0.4) { + t_leq += 1 + } + if (x < 0.4) { + t_le += 1 + } + if (x >= 0.4) { + t_geq += 1 + } + if (x > 0.4) { + t_ge += 1 + } + if (x == 0.4) { + t_eq += 1 + } + if (x != 0.4) { + t_neq += 1 + } + } + println(t_leq) + println(t_le) + println(t_geq) + println(t_ge) + println(t_eq) + println(t_neq) +} \ No newline at end of file