Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions src/riscv_generate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ let global_vars = ref Stringset.empty
(** The function/closure we're currently dealing with *)
let current_function = ref ""

(** The origin current function name. Used for judge whether a closure calls itself *)
let current_base_name = ref ""

(** The environment (i.e. the pointer passed to the closure) of current closure *)
let current_env = ref unit

Expand Down Expand Up @@ -399,7 +402,8 @@ let deal_with_prim tac rd (prim: Primitive.prim) args =

let altered = new_temp Mtype.T_string in
Vec.push tac (Add { rd = altered; rs1 = str; rs2 = i });
Vec.push tac (Store { rd = item; rs = altered; offset = 0; byte = 1 })
Vec.push tac (Store { rd = item; rs = altered; offset = 0; byte = 1 });
Vec.push tac (Assign { rd; rs = unit })

(* Be cautious that each `char` is 2 bytes long, which is extremely counter-intuitive. *)
| Pgetstringitem ->
Expand Down Expand Up @@ -806,14 +810,21 @@ let rec do_convert tac (expr: Mcore.expr) =
(if Stringset.mem fn !fn_names then
Vec.push tac (Call { rd; fn; args })
else
(* Here `fn` is a closure *)
let closure = { name = fn; ty = Mtype.T_bytes } in
let fptr = new_temp Mtype.T_bytes in
Vec.push tac (Load { rd = fptr; rs = closure; offset = 0; byte = pointer_size });
if fn = !current_base_name then
let fptr = new_temp Mtype.T_bytes in
Vec.push tac (Load { rd = fptr; rs = !current_env; offset = 0; byte = pointer_size });

let args = args @ [!current_env] in
Vec.push tac (CallIndirect { rd; rs = fptr; args });
else
(* Here `fn` is a closure *)
let closure = { name = fn; ty = Mtype.T_bytes } in
let fptr = new_temp Mtype.T_bytes in
Vec.push tac (Load { rd = fptr; rs = closure; offset = 0; byte = pointer_size });

(* Closure, along with environment, should be passed as argument *)
let args = args @ [closure] in
Vec.push tac (CallIndirect { rd; rs = fptr; args }));
(* Closure, along with environment, should be passed as argument *)
let args = args @ [closure] in
Vec.push tac (CallIndirect { rd; rs = fptr; args }));

(* If this is a `Join`, then we must jump to the corresponding letfn *)
if kind = Join then (
Expand Down Expand Up @@ -1081,13 +1092,15 @@ let rec do_convert tac (expr: Mcore.expr) =
(* This is a different function from the current one, *)
(* so we must protect all global variables before generating body *)
let this_fn = !current_function in
let this_base_fn = !current_base_name in
let this_env = !current_env in
let this_join = !current_join in
let this_join_ret = !current_join_ret in

(* Set the correct values for this new function *)
let fn_name = Printf.sprintf "%s_closure_%s" !current_function name in
current_function := fn_name;
current_base_name := name;
current_env := fn_env;
current_join := "";
current_join_ret := unit;
Expand All @@ -1098,6 +1111,7 @@ let rec do_convert tac (expr: Mcore.expr) =

(* Put them back *)
current_function := this_fn;
current_base_name := this_base_fn;
current_env := this_env;
current_join := this_join;
current_join_ret := this_join_ret;
Expand All @@ -1122,7 +1136,7 @@ let rec do_convert tac (expr: Mcore.expr) =
(* Store environment variables *)
List.iter2 (fun (arg: var) offset ->
let size = sizeof arg.ty in
if arg.name = !current_function then
if arg.name = !current_base_name then
(* This closure captures myself, so I need to make myself a closure *)
(* Fortunately my environment is just my closure *)
Vec.push tac (Store { rd = !current_env; rs = closure; offset = offset - size; byte = size })
Expand Down Expand Up @@ -1720,6 +1734,7 @@ let convert_toplevel _start (top: Mcore.top_item) =
let args = List.map var_of_param func.params in

current_function := fn;
current_base_name := fn;
let body = convert_expr func.body in

(* Record the index of arguments that are traits *)
Expand Down
34 changes: 33 additions & 1 deletion src/riscv_opt_peephole.ml
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,37 @@ let to_itype fn =
in
List.iter (fun block -> (block_of block).body <- convert block) blocks

let purity_table = Hashtbl.create 64

let rec is_pure fn =
match Hashtbl.find_opt purity_table fn with
| Some x -> x
| None ->
let global_vars = !Riscv_generate.global_vars in
let alloca_data = ref Stringset.empty in
let blocks = get_blocks fn in
let pure =
List.for_all (fun block ->
List.for_all (fun x ->
let is_global = ref false in
reg_iterd (fun var ->
if Stringset.mem var.name global_vars then
is_global := true;
var) x;
not !is_global
) (body_of block) &&
List.for_all (fun x -> match x with
| Alloca { rd } -> alloca_data := Stringset.add rd.name !alloca_data; true
| Call { fn = fn' } -> if fn <> fn' then is_pure fn' else false
| Store { rs } -> Stringset.mem rs.name !alloca_data
| CallExtern _ | CallIndirect _
| Malloc _ -> false
| _ -> true) (body_of block)
) blocks
in
Hashtbl.add purity_table fn pure;
pure

let remove_dead_variable fn =
let blocks = get_blocks fn in
let liveness = liveness_analysis fn in
Expand All @@ -140,7 +171,8 @@ let remove_dead_variable fn =

(* TODO: refine this, so that calls to pure functions are also eliminated *)
match x with
| Call _ | CallExtern _ | CallIndirect _ -> true
| Call { fn } when not (is_pure fn) -> true
| CallExtern _ | CallIndirect _ -> true
| _ -> !preserve
) body;
|> Basic_vec.of_list
Expand Down
15 changes: 11 additions & 4 deletions src/riscv_ssa.ml
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,17 @@ let to_string t =

(* Deals with signedness: signed or unsigned *)
(* In most R-type instructions, `rd` and the 2 `rs`es have the same type *)

(* For unsigned comparison, `rd` may have different type *)
(* SSA interpreter relies on this function, so dealing with `most` is not enough *)
let rtypeu op ({ rd; rs1; rs2 }: r_type) =
let width = (match rd.ty with
| T_uint | T_uint64 -> "u"
| _ -> "") in
let isu_rs1 = (match rs1.ty with
| T_uint | T_uint64 -> true
| _ -> false) in
let isu_rs2 = (match rs2.ty with
| T_uint | T_uint64 -> true
| _ -> false) in
let width = if (isu_rs1 && isu_rs2) then "u" else "" in
Printf.sprintf "%s%s %s %s %s" op width rd.name rs1.name rs2.name
in

Expand Down Expand Up @@ -406,7 +413,7 @@ let to_string t =
Printf.sprintf "li %s %s" rd.name (Int64.to_string imm)

| AssignFP { rd; imm; } ->
Printf.sprintf "fli %s = %f" rd.name imm
Printf.sprintf "fli %s %f" rd.name imm

| AssignLabel { rd; imm; } ->
Printf.sprintf "la %s %s" rd.name imm
Expand Down
37 changes: 20 additions & 17 deletions src/riscv_virtasm_generate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -375,26 +375,29 @@ let convert_single name body terminator (inst: Riscv_ssa.t) =
terminator := Term.J (label_of label)

| JumpIndirect { rs; possibilities } -> (* TODO: Optimizations on possibilities *)
terminator := Term.Jalr {
rd = Slot.Reg Zero;
rs1 = slot_v rs;
offset = 0;
};
if List.length possibilities = 1 then
terminator := Term.J (label_of (List.hd possibilities))
else
terminator := Term.Jalr {
rd = Slot.Reg Zero;
rs1 = slot_v rs;
offset = 0;
};

(* Floating point instructions *)
| FAdd _ -> ()
| FSub _ -> ()
| FMul _ -> ()
| FDiv _ -> ()
| FLess _ -> ()
| FLeq _ -> ()
| FGreat _ -> ()
| FGeq _ -> ()
| FEq _ -> ()
| FNeq _ -> ()
| FNeg _ -> ()
| AssignFP _ -> ()
(*
| FAdd -> _
| FSub -> _
| FMul -> _
| FDiv -> _
| FLess -> _
| FLeq -> _
| FGreat -> _
| FGeq -> _
| FEq -> _
| FNeq -> _
| FNeg -> _
| AssignFP -> _
| FnDecl -> _
| GlobalVarDecl -> _
| ExtArray _ -> _ *)
Expand Down
65 changes: 63 additions & 2 deletions test/interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ std::map<std::string, std::vector<std::string>> blocks;
std::map<std::string, std::vector<std::string>> fns;

// Values of registers used when interpreting.
// TODO: currently no FP supported.
std::map<std::string, int64_t> regs;
std::map<std::string, int64_t> labels;

Expand Down Expand Up @@ -64,13 +63,35 @@ int int_of(std::string s) {
#define RTYPEU(name, op) std::make_pair(name, [](int64_t x, int64_t y) { return (uint64_t)x op (uint64_t)y; })
#define RTYPEUW(name, op) std::make_pair(name, [](int64_t x, int64_t y) { return (unsigned)x op (unsigned)y; })
#define ITYPEW(name, op) std::make_pair(name, [](int64_t x, int imm) -> int64_t { return (int)x op imm; })
#define RTYPEF(name, op) std::make_pair(name, [](int64_t x, int64_t y) { \
double fx, fy; \
std::memcpy(&fx, &x, 8); \
std::memcpy(&fy, &y, 8); \
double fres = fx op fy; \
int64_t res; \
std::memcpy(&res, &fres, 8); \
return res; \
})
#define RTYPEFCMP(name, op) std::make_pair(name, [](int64_t x, int64_t y) { \
double fx, fy; \
std::memcpy(&fx, &x, 8); \
std::memcpy(&fy, &y, 8); \
int64_t res = fx op fy; \
return res; \
})
#define VAL(i) regs[args[i]]

#ifdef VERBOSE
#define OUTPUT(name, value) std::cerr << "\t" << name << " = " << value << "\n\n"
#define OUTPUTF(name, value) \
double fvalue; \
std::memcpy(&fvalue, &value, 8); \
std::cerr << "\t" << name << " = " << fvalue << "\n\n"

#define SAY(str) std::cerr << str << "\n"
#else
#define OUTPUT(name, value)
#define OUTPUTF(name, value)
#define SAY(str)
#endif

Expand Down Expand Up @@ -118,6 +139,19 @@ int64_t interpret(std::string label) {
RTYPEUW("moduw", %),
};

static std::map<std::string, std::function<int64_t (int64_t, int64_t)>> rtypef = {
RTYPEF("fadd", +),
RTYPEF("fsub", -),
RTYPEF("fmul", *),
RTYPEF("fdiv", /),
RTYPEFCMP("fle", <),
RTYPEFCMP("fleq", <=),
RTYPEFCMP("fge", >),
RTYPEFCMP("fgeq", >=),
RTYPEFCMP("feq", ==),
RTYPEFCMP("fneq", !=),
};

static std::map<std::string, std::function<int64_t (int64_t, int)>> load = {
MEM("lb", char),
MEM("lh", char16_t),
Expand Down Expand Up @@ -158,6 +192,12 @@ int64_t interpret(std::string label) {
continue;
}

if (rtypef.contains(op)) {
VAL(1) = rtypef[op](VAL(2), VAL(3));
OUTPUTF(args[1], VAL(1));
continue;
}

if (load.contains(op)) {
VAL(1) = load[op](VAL(2), int_of(args[3]));
OUTPUT(args[1], VAL(1));
Expand Down Expand Up @@ -215,8 +255,17 @@ int64_t interpret(std::string label) {
continue;
}

if (op == "fneg") {
int64_t val = VAL(2);
double fval;
std::memcpy(&fval, &val, 8);
fval = -fval;
std::memcpy(&VAL(1), &fval, 8);
continue;
}

if (op == "not") {
VAL(1) = ~VAL(2);
VAL(1) = VAL(2) == 0;
continue;
}

Expand Down Expand Up @@ -379,6 +428,18 @@ int64_t interpret(std::string label) {
continue;
}

if (op == "fli"){
std::stringstream ss(args[2]);
double rs;
ss >> rs;

int64_t rs_int;
std::memcpy(&rs_int, &rs, 8);
VAL(1) = rs_int;
OUTPUT(args[1], rs);
continue;
}

if (op == "mv") {
VAL(1) = VAL(2);
OUTPUT(args[1], VAL(1));
Expand Down
2 changes: 2 additions & 0 deletions test/src/closure05/closure05.ans
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
60
()
26 changes: 26 additions & 0 deletions test/src/closure05/closure05.mbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
fn create(n: Int) -> (Int) -> (Int) -> (Int) -> Unit {

fn layer1(x: Int) {

fn layer2(y: Int) {

fn layer3(z: Int) {
let _ = if z <= 0 {
println(x + y + n)
} else {
layer3(z - 1)
};
};

layer3
};

layer2
};

layer1
};

fn main {
println(create(10)(20)(30)(1));
};
6 changes: 6 additions & 0 deletions test/src/floatpoint01/floatpoint01.ans
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
3.3
-1.1
2.42
0.5
[1.1, 2.2, 3.3, 4.4, 5.5]
-3.14
13 changes: 13 additions & 0 deletions test/src/floatpoint01/floatpoint01.mbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
fn f(x : Double) -> Double {
-x
}

fn main {
println(1.1 + 2.2)
println(1.1 - 2.2)
println(1.1 * 2.2)
println(1.1 / 2.2)
let arr1 = [1.1, 2.2, 3.3, 4.4, 5.5]
println(arr1)
println(f(3.14))
}
6 changes: 6 additions & 0 deletions test/src/floatpoint02/floatpoint02.ans
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
4
3
7
6
1
9
Loading
Loading