diff --git a/src/riscv_generate.ml b/src/riscv_generate.ml index e685114..09a5024 100644 --- a/src/riscv_generate.ml +++ b/src/riscv_generate.ml @@ -14,6 +14,9 @@ let global_vars = ref Stringset.empty (** The function/closure we're currently dealing with *) let current_function = ref "" +(** The origin current function name. Used for judge whether a closure calls itself *) +let current_base_name = ref "" + (** The environment (i.e. the pointer passed to the closure) of current closure *) let current_env = ref unit @@ -399,7 +402,8 @@ let deal_with_prim tac rd (prim: Primitive.prim) args = let altered = new_temp Mtype.T_string in Vec.push tac (Add { rd = altered; rs1 = str; rs2 = i }); - Vec.push tac (Store { rd = item; rs = altered; offset = 0; byte = 1 }) + Vec.push tac (Store { rd = item; rs = altered; offset = 0; byte = 1 }); + Vec.push tac (Assign { rd; rs = unit }) (* Be cautious that each `char` is 2 bytes long, which is extremely counter-intuitive. *) | Pgetstringitem -> @@ -806,14 +810,21 @@ let rec do_convert tac (expr: Mcore.expr) = (if Stringset.mem fn !fn_names then Vec.push tac (Call { rd; fn; args }) else - (* Here `fn` is a closure *) - let closure = { name = fn; ty = Mtype.T_bytes } in - let fptr = new_temp Mtype.T_bytes in - Vec.push tac (Load { rd = fptr; rs = closure; offset = 0; byte = pointer_size }); + if fn = !current_base_name then + let fptr = new_temp Mtype.T_bytes in + Vec.push tac (Load { rd = fptr; rs = !current_env; offset = 0; byte = pointer_size }); + + let args = args @ [!current_env] in + Vec.push tac (CallIndirect { rd; rs = fptr; args }); + else + (* Here `fn` is a closure *) + let closure = { name = fn; ty = Mtype.T_bytes } in + let fptr = new_temp Mtype.T_bytes in + Vec.push tac (Load { rd = fptr; rs = closure; offset = 0; byte = pointer_size }); - (* Closure, along with environment, should be passed as argument *) - let args = args @ [closure] in - Vec.push tac (CallIndirect { rd; rs = fptr; args })); + (* Closure, along with environment, should be passed as argument *) + let args = args @ [closure] in + Vec.push tac (CallIndirect { rd; rs = fptr; args })); (* If this is a `Join`, then we must jump to the corresponding letfn *) if kind = Join then ( @@ -1081,6 +1092,7 @@ let rec do_convert tac (expr: Mcore.expr) = (* This is a different function from the current one, *) (* so we must protect all global variables before generating body *) let this_fn = !current_function in + let this_base_fn = !current_base_name in let this_env = !current_env in let this_join = !current_join in let this_join_ret = !current_join_ret in @@ -1088,6 +1100,7 @@ let rec do_convert tac (expr: Mcore.expr) = (* Set the correct values for this new function *) let fn_name = Printf.sprintf "%s_closure_%s" !current_function name in current_function := fn_name; + current_base_name := name; current_env := fn_env; current_join := ""; current_join_ret := unit; @@ -1098,6 +1111,7 @@ let rec do_convert tac (expr: Mcore.expr) = (* Put them back *) current_function := this_fn; + current_base_name := this_base_fn; current_env := this_env; current_join := this_join; current_join_ret := this_join_ret; @@ -1122,7 +1136,7 @@ let rec do_convert tac (expr: Mcore.expr) = (* Store environment variables *) List.iter2 (fun (arg: var) offset -> let size = sizeof arg.ty in - if arg.name = !current_function then + if arg.name = !current_base_name then (* This closure captures myself, so I need to make myself a closure *) (* Fortunately my environment is just my closure *) Vec.push tac (Store { rd = !current_env; rs = closure; offset = offset - size; byte = size }) @@ -1720,6 +1734,7 @@ let convert_toplevel _start (top: Mcore.top_item) = let args = List.map var_of_param func.params in current_function := fn; + current_base_name := fn; let body = convert_expr func.body in (* Record the index of arguments that are traits *) diff --git a/src/riscv_opt_peephole.ml b/src/riscv_opt_peephole.ml index ce3d5e9..1eac881 100644 --- a/src/riscv_opt_peephole.ml +++ b/src/riscv_opt_peephole.ml @@ -116,6 +116,37 @@ let to_itype fn = in List.iter (fun block -> (block_of block).body <- convert block) blocks +let purity_table = Hashtbl.create 64 + +let rec is_pure fn = + match Hashtbl.find_opt purity_table fn with + | Some x -> x + | None -> + let global_vars = !Riscv_generate.global_vars in + let alloca_data = ref Stringset.empty in + let blocks = get_blocks fn in + let pure = + List.for_all (fun block -> + List.for_all (fun x -> + let is_global = ref false in + reg_iterd (fun var -> + if Stringset.mem var.name global_vars then + is_global := true; + var) x; + not !is_global + ) (body_of block) && + List.for_all (fun x -> match x with + | Alloca { rd } -> alloca_data := Stringset.add rd.name !alloca_data; true + | Call { fn = fn' } -> if fn <> fn' then is_pure fn' else false + | Store { rs } -> Stringset.mem rs.name !alloca_data + | CallExtern _ | CallIndirect _ + | Malloc _ -> false + | _ -> true) (body_of block) + ) blocks + in + Hashtbl.add purity_table fn pure; + pure + let remove_dead_variable fn = let blocks = get_blocks fn in let liveness = liveness_analysis fn in @@ -140,7 +171,8 @@ let remove_dead_variable fn = (* TODO: refine this, so that calls to pure functions are also eliminated *) match x with - | Call _ | CallExtern _ | CallIndirect _ -> true + | Call { fn } when not (is_pure fn) -> true + | CallExtern _ | CallIndirect _ -> true | _ -> !preserve ) body; |> Basic_vec.of_list diff --git a/src/riscv_ssa.ml b/src/riscv_ssa.ml index 9e03905..5622ad0 100644 --- a/src/riscv_ssa.ml +++ b/src/riscv_ssa.ml @@ -283,10 +283,17 @@ let to_string t = (* Deals with signedness: signed or unsigned *) (* In most R-type instructions, `rd` and the 2 `rs`es have the same type *) + + (* For unsigned comparison, `rd` may have different type *) + (* SSA interpreter relies on this function, so dealing with `most` is not enough *) let rtypeu op ({ rd; rs1; rs2 }: r_type) = - let width = (match rd.ty with - | T_uint | T_uint64 -> "u" - | _ -> "") in + let isu_rs1 = (match rs1.ty with + | T_uint | T_uint64 -> true + | _ -> false) in + let isu_rs2 = (match rs2.ty with + | T_uint | T_uint64 -> true + | _ -> false) in + let width = if (isu_rs1 && isu_rs2) then "u" else "" in Printf.sprintf "%s%s %s %s %s" op width rd.name rs1.name rs2.name in @@ -406,7 +413,7 @@ let to_string t = Printf.sprintf "li %s %s" rd.name (Int64.to_string imm) | AssignFP { rd; imm; } -> - Printf.sprintf "fli %s = %f" rd.name imm + Printf.sprintf "fli %s %f" rd.name imm | AssignLabel { rd; imm; } -> Printf.sprintf "la %s %s" rd.name imm diff --git a/src/riscv_virtasm_generate.ml b/src/riscv_virtasm_generate.ml index 24ee9e1..aace00e 100644 --- a/src/riscv_virtasm_generate.ml +++ b/src/riscv_virtasm_generate.ml @@ -375,26 +375,29 @@ let convert_single name body terminator (inst: Riscv_ssa.t) = terminator := Term.J (label_of label) | JumpIndirect { rs; possibilities } -> (* TODO: Optimizations on possibilities *) - terminator := Term.Jalr { - rd = Slot.Reg Zero; - rs1 = slot_v rs; - offset = 0; - }; + if List.length possibilities = 1 then + terminator := Term.J (label_of (List.hd possibilities)) + else + terminator := Term.Jalr { + rd = Slot.Reg Zero; + rs1 = slot_v rs; + offset = 0; + }; (* Floating point instructions *) + | FAdd _ -> () + | FSub _ -> () + | FMul _ -> () + | FDiv _ -> () + | FLess _ -> () + | FLeq _ -> () + | FGreat _ -> () + | FGeq _ -> () + | FEq _ -> () + | FNeq _ -> () + | FNeg _ -> () + | AssignFP _ -> () (* - | FAdd -> _ - | FSub -> _ - | FMul -> _ - | FDiv -> _ - | FLess -> _ - | FLeq -> _ - | FGreat -> _ - | FGeq -> _ - | FEq -> _ - | FNeq -> _ - | FNeg -> _ - | AssignFP -> _ | FnDecl -> _ | GlobalVarDecl -> _ | ExtArray _ -> _ *) diff --git a/test/interpreter.cpp b/test/interpreter.cpp index 43db751..745496c 100644 --- a/test/interpreter.cpp +++ b/test/interpreter.cpp @@ -19,7 +19,6 @@ std::map> blocks; std::map> fns; // Values of registers used when interpreting. -// TODO: currently no FP supported. std::map regs; std::map labels; @@ -64,13 +63,35 @@ int int_of(std::string s) { #define RTYPEU(name, op) std::make_pair(name, [](int64_t x, int64_t y) { return (uint64_t)x op (uint64_t)y; }) #define RTYPEUW(name, op) std::make_pair(name, [](int64_t x, int64_t y) { return (unsigned)x op (unsigned)y; }) #define ITYPEW(name, op) std::make_pair(name, [](int64_t x, int imm) -> int64_t { return (int)x op imm; }) +#define RTYPEF(name, op) std::make_pair(name, [](int64_t x, int64_t y) { \ + double fx, fy; \ + std::memcpy(&fx, &x, 8); \ + std::memcpy(&fy, &y, 8); \ + double fres = fx op fy; \ + int64_t res; \ + std::memcpy(&res, &fres, 8); \ + return res; \ + }) +#define RTYPEFCMP(name, op) std::make_pair(name, [](int64_t x, int64_t y) { \ + double fx, fy; \ + std::memcpy(&fx, &x, 8); \ + std::memcpy(&fy, &y, 8); \ + int64_t res = fx op fy; \ + return res; \ + }) #define VAL(i) regs[args[i]] #ifdef VERBOSE #define OUTPUT(name, value) std::cerr << "\t" << name << " = " << value << "\n\n" +#define OUTPUTF(name, value) \ + double fvalue; \ + std::memcpy(&fvalue, &value, 8); \ + std::cerr << "\t" << name << " = " << fvalue << "\n\n" + #define SAY(str) std::cerr << str << "\n" #else #define OUTPUT(name, value) +#define OUTPUTF(name, value) #define SAY(str) #endif @@ -118,6 +139,19 @@ int64_t interpret(std::string label) { RTYPEUW("moduw", %), }; + static std::map> rtypef = { + RTYPEF("fadd", +), + RTYPEF("fsub", -), + RTYPEF("fmul", *), + RTYPEF("fdiv", /), + RTYPEFCMP("fle", <), + RTYPEFCMP("fleq", <=), + RTYPEFCMP("fge", >), + RTYPEFCMP("fgeq", >=), + RTYPEFCMP("feq", ==), + RTYPEFCMP("fneq", !=), + }; + static std::map> load = { MEM("lb", char), MEM("lh", char16_t), @@ -158,6 +192,12 @@ int64_t interpret(std::string label) { continue; } + if (rtypef.contains(op)) { + VAL(1) = rtypef[op](VAL(2), VAL(3)); + OUTPUTF(args[1], VAL(1)); + continue; + } + if (load.contains(op)) { VAL(1) = load[op](VAL(2), int_of(args[3])); OUTPUT(args[1], VAL(1)); @@ -215,8 +255,17 @@ int64_t interpret(std::string label) { continue; } + if (op == "fneg") { + int64_t val = VAL(2); + double fval; + std::memcpy(&fval, &val, 8); + fval = -fval; + std::memcpy(&VAL(1), &fval, 8); + continue; + } + if (op == "not") { - VAL(1) = ~VAL(2); + VAL(1) = VAL(2) == 0; continue; } @@ -379,6 +428,18 @@ int64_t interpret(std::string label) { continue; } + if (op == "fli"){ + std::stringstream ss(args[2]); + double rs; + ss >> rs; + + int64_t rs_int; + std::memcpy(&rs_int, &rs, 8); + VAL(1) = rs_int; + OUTPUT(args[1], rs); + continue; + } + if (op == "mv") { VAL(1) = VAL(2); OUTPUT(args[1], VAL(1)); diff --git a/test/src/closure05/closure05.ans b/test/src/closure05/closure05.ans new file mode 100644 index 0000000..4f7108e --- /dev/null +++ b/test/src/closure05/closure05.ans @@ -0,0 +1,2 @@ +60 +() diff --git a/test/src/closure05/closure05.mbt b/test/src/closure05/closure05.mbt new file mode 100644 index 0000000..76f191b --- /dev/null +++ b/test/src/closure05/closure05.mbt @@ -0,0 +1,26 @@ +fn create(n: Int) -> (Int) -> (Int) -> (Int) -> Unit { + + fn layer1(x: Int) { + + fn layer2(y: Int) { + + fn layer3(z: Int) { + let _ = if z <= 0 { + println(x + y + n) + } else { + layer3(z - 1) + }; + }; + + layer3 + }; + + layer2 + }; + + layer1 +}; + +fn main { + println(create(10)(20)(30)(1)); +}; \ No newline at end of file diff --git a/test/src/floatpoint01/floatpoint01.ans b/test/src/floatpoint01/floatpoint01.ans new file mode 100644 index 0000000..cdb0386 --- /dev/null +++ b/test/src/floatpoint01/floatpoint01.ans @@ -0,0 +1,6 @@ +3.3 +-1.1 +2.42 +0.5 +[1.1, 2.2, 3.3, 4.4, 5.5] +-3.14 diff --git a/test/src/floatpoint01/floatpoint01.mbt b/test/src/floatpoint01/floatpoint01.mbt new file mode 100644 index 0000000..74e69ca --- /dev/null +++ b/test/src/floatpoint01/floatpoint01.mbt @@ -0,0 +1,13 @@ +fn f(x : Double) -> Double { + -x +} + +fn main { + println(1.1 + 2.2) + println(1.1 - 2.2) + println(1.1 * 2.2) + println(1.1 / 2.2) + let arr1 = [1.1, 2.2, 3.3, 4.4, 5.5] + println(arr1) + println(f(3.14)) +} \ No newline at end of file diff --git a/test/src/floatpoint02/floatpoint02.ans b/test/src/floatpoint02/floatpoint02.ans new file mode 100644 index 0000000..17607c6 --- /dev/null +++ b/test/src/floatpoint02/floatpoint02.ans @@ -0,0 +1,6 @@ +4 +3 +7 +6 +1 +9 diff --git a/test/src/floatpoint02/floatpoint02.mbt b/test/src/floatpoint02/floatpoint02.mbt new file mode 100644 index 0000000..311f4a0 --- /dev/null +++ b/test/src/floatpoint02/floatpoint02.mbt @@ -0,0 +1,37 @@ +fn main { + let mut x : Double = 0 + let mut t_geq : Int = 0 + let mut t_ge : Int = 0 + let mut t_leq : Int = 0 + let mut t_le : Int = 0 + let mut t_eq : Int = 0 + let mut t_neq : Int = 0 + + for i = 0; i < 10; i = i + 1 { + x += 0.1 + if (x <= 0.4) { + t_leq += 1 + } + if (x < 0.4) { + t_le += 1 + } + if (x >= 0.4) { + t_geq += 1 + } + if (x > 0.4) { + t_ge += 1 + } + if (x == 0.4) { + t_eq += 1 + } + if (x != 0.4) { + t_neq += 1 + } + } + println(t_leq) + println(t_le) + println(t_geq) + println(t_ge) + println(t_eq) + println(t_neq) +} \ No newline at end of file