From 9d35c3a1d48c4a42af4ddd9410e7f03655edb453 Mon Sep 17 00:00:00 2001 From: bmwangmh Date: Thu, 18 Sep 2025 13:12:26 +0800 Subject: [PATCH 1/2] Bugfixs for reg_spill --- src/riscv_reg_spill.ml | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/riscv_reg_spill.ml b/src/riscv_reg_spill.ml index 4b0f2ff..b277dbb 100644 --- a/src/riscv_reg_spill.ml +++ b/src/riscv_reg_spill.ml @@ -301,6 +301,9 @@ let initUsual_entryW let entryW_I = split_init_entryW freqs_I cands_I entryNextUse_I k_I in let entryW_F = split_init_entryW freqs_F cands_F entryNextUse_F k_F in let entryW = SlotSet.union entryW_I entryW_F in + let entryW = SlotSet.filter entryW (fun var -> match var with + | Slot.Slot _ | Slot.FSlot _ -> true + | _ -> false) in update_spillinfo bl { binfo with entryW }; () ;; @@ -335,7 +338,7 @@ let init_entryW (bl : VBlockLabel.t) (entryNextUse : int SlotMap.t) = Vec.push normal_pred pred_bl; let pred_info = get_spillinfo pred_bl in SlotSet.iter pred_info.exitW (fun var -> - match SlotMap.find_opt b_liveinfo.exitNextUse var with + match SlotMap.find_opt b_liveinfo.entryNextUse var with | Some dist -> if not_inf dist then ( @@ -372,6 +375,9 @@ let init_entryS (bl : VBlockLabel.t) = entryS := SlotSet.union !entryS pred_info.exitS) preds; entryS := SlotSet.inter !entryS binfo.entryW; + entryS := SlotSet.filter !entryS (fun var -> match var with + | Slot.Slot _ | Slot.FSlot _ -> true + | _ -> false); update_spillinfo bl { binfo with entryS = !entryS }; () ;; @@ -393,6 +399,9 @@ let append_spill_to_pred_block (bl : VBlockLabel.t) = (* a-Reload instructions at the end of the predecessor block *) let need_reload = SlotSet.diff entryW pred_info.exitW in + let need_reload = SlotSet.filter need_reload (fun var -> match var with + | Slot.Slot _ | Slot.FSlot _ -> true + | _ -> false) in let reload_insts = generate_trailing_reload need_reload in Vec.append block.body reload_insts; @@ -424,9 +433,9 @@ let limit_func (fun var -> if (not 
(SlotSet.mem !s var)) - && SlotMap.find_default nextUse.![i] var max_int < max_int + (* && SlotMap.find_default nextUse.![i] var max_int < max_int *) then spill := SlotSet.add !spill var; - s := SlotSet.remove !s var; + s := SlotSet.add !s var; w := SlotSet.remove !w var; nextUse.![i] <- SlotMap.remove nextUse.![i] var) (list_drop k sorted_w) @@ -462,6 +471,10 @@ let apply_min_algorithm (bl : VBlockLabel.t) (nextUse : int SlotMap.t Vec.t) = (adjust_k : int) = let reload = ref (SlotSet.diff srcs !w) in + (* reload := SlotSet.filter !reload (fun var -> (Option.is_none @@ SlotMap.find_opt !reg_map var)); *) + reload := SlotSet.filter !reload (fun var -> match var with + | Slot.Slot _ | Slot.FSlot _ -> true + | _ -> false); let spill = ref SlotSet.empty in let protected = srcs in (* At this point, protected protects the registers being used *) @@ -469,7 +482,7 @@ let apply_min_algorithm (bl : VBlockLabel.t) (nextUse : int SlotMap.t Vec.t) = (* a. Compute the variables that need to be reloaded *) SlotSet.iter !reload (fun var -> w := SlotSet.add !w var; - s := SlotSet.add !s var); + s := SlotSet.remove !s var); (* Adjust the maximum number of allocatable registers *) (* b. Leave registers for src *) @@ -480,8 +493,8 @@ let apply_min_algorithm (bl : VBlockLabel.t) (nextUse : int SlotMap.t Vec.t) = (* Further reduce k *) (* d. 
Add defs to w_I *) - SlotSet.iter dests (fun var -> w := SlotSet.add !w var); - let protected = dests in + SlotSet.iter dests (fun var -> w := SlotSet.add !w var; s := SlotSet.remove !s var); + let protected = SlotSet.union dests protected in (* At this point, protected protects the registers being defined *) if i <> body_size then limit_func nextUse w s spill protected (i + 1) adjust_k; From 637433f78aa4cc8797f28ab3fd91403cf64c5a6d Mon Sep 17 00:00:00 2001 From: bmwangmh Date: Thu, 18 Sep 2025 19:57:10 +0800 Subject: [PATCH 2/2] feat: Register allocation --- src/riscv_reg.ml | 8 +- src/riscv_reg_alloc.ml | 423 ++++++++++++++++++++++++++++++++++++++++- src/riscv_reg_util.ml | 5 +- src/riscv_virtasm.ml | 132 ++++++++++++- 4 files changed, 557 insertions(+), 11 deletions(-) diff --git a/src/riscv_reg.ml b/src/riscv_reg.ml index ff5c808..40ec4de 100644 --- a/src/riscv_reg.ml +++ b/src/riscv_reg.ml @@ -74,12 +74,18 @@ module Reg = struct ;; (* Reg寄存器最大可分配数量*) - let k = 32 + let k = 11 (* 用于调用者保存寄存器*) let caller_saved_regs = [ Ra; T0; T1; T2; A0; A1; A2; A3; A4; A5; A6; A7; T3; T4; T5; T6 ] ;; + let callee_saved_regs = + [ S1; S2; S3; S4; S5; S6; S7; S8; S9; S10; S11 ] + ;; + + let spill_reg = T0 + ;; end (* Module for floating-point registers (freg_t) *) diff --git a/src/riscv_reg_alloc.ml b/src/riscv_reg_alloc.ml index 74fca3d..72da628 100644 --- a/src/riscv_reg_alloc.ml +++ b/src/riscv_reg_alloc.ml @@ -2,11 +2,424 @@ open Riscv_reg open Riscv_virtasm open Riscv_reg_util +module Spill = Riscv_reg_spill -let reg_alloc (vprog: VProg.t) = - let rpo = RPO.calculate_rpo vprog in - Riscv_reg_spill.spill_regs vprog rpo; +let rpo = ref RPO.empty + +let vprog = ref VProg.empty + +module DisjointSet = struct + type 'a t = ('a, ('a * int) ref) Hashtbl.t + + let create size = Hashtbl.create size + + let make_set dsu element = + if not (Hashtbl.mem dsu element) then + Hashtbl.add dsu element (ref (element, 0)) + + let rec find dsu element = + let parent, _ = !(Hashtbl.find dsu 
element) in + if parent = element then + Hashtbl.find dsu element + else + let root_ref = find dsu parent in + Hashtbl.find dsu element := !root_ref; + root_ref + + let union dsu elem1 elem2 = + let root1_ref = find dsu elem1 in + let root2_ref = find dsu elem2 in + let (root1, rank1) = !root1_ref in + let (root2, rank2) = !root2_ref in + if root1 <> root2 then + if rank1 > rank2 then + root2_ref := (root1, rank2) + else begin + root1_ref := (root2, rank1); + if rank1 = rank2 then + root2_ref := (root2, rank2 + 1) + end + + let are_connected dsu elem1 elem2 = + try + let root1, _ = !(find dsu elem1) in + let root2, _ = !(find dsu elem2) in + root1 = root2 + with Not_found -> false +end + +module AllocEnv = struct + (** + Data structure for data-flow analysis. + Here W stands for working registers (those aren't spilled); + S stands for spilled registers. + *) + type alloc_info = + { entry_map : Slot.t SlotMap.t + ; exit_map : Slot.t SlotMap.t + } + + type t = alloc_info VBlockMap.t + + let empty_info : alloc_info = + { entry_map = SlotMap.empty + ; exit_map = SlotMap.empty + } + ;; + + let empty : t = VBlockMap.empty + + let get_allocinfo (alloc_env : t) (bl : VBlockLabel.t) : alloc_info = + match VBlockMap.find_opt alloc_env bl with + | Some x -> x + | None -> empty_info + ;; + + let update_allocinfo (alloc_env : t) (bl : VBlockLabel.t) (info : alloc_info) : t = + VBlockMap.add alloc_env bl info + ;; +end + +let alloc_env : AllocEnv.t ref = ref AllocEnv.empty + +let available_regs = SlotSet.of_list (List.map (fun r -> Slot.Reg r) Reg.callee_saved_regs) + +let get_allocinfo (bl : VBlockLabel.t) : AllocEnv.alloc_info = + AllocEnv.get_allocinfo !alloc_env bl +;; + +let update_allocinfo (bl : VBlockLabel.t) (info : AllocEnv.alloc_info) : unit = + alloc_env := AllocEnv.update_allocinfo !alloc_env bl info +;; + +let incr_freq_by_one (freqs : (int SlotMap.t) SlotMap.t) (slot : Slot.t) (reg : Slot.t) : (int SlotMap.t) SlotMap.t = + match SlotMap.find_opt freqs slot with + 
| Some mp -> + (match SlotMap.find_opt mp reg with + | Some x -> SlotMap.add freqs slot (SlotMap.add mp reg (x + 1)) + | None -> SlotMap.add freqs slot (SlotMap.add mp reg 1)) + | None -> SlotMap.add freqs slot (SlotMap.singleton reg 1) +;; + +(* Helper function : Find the register used with the maximum frequency in predecessors *) +let find_max_freq (freq_map_opt : int SlotMap.t option) : Slot.t option = + Option.bind freq_map_opt (fun freq_map -> + if SlotMap.is_empty freq_map then + None + else + let _, max_reg = + SlotMap.fold freq_map (0, Slot.Unit) (fun key freq (max_so_far, reg_so_far) -> + if freq > max_so_far then + (freq, key) + else + (max_so_far, reg_so_far)) + in + Some max_reg) + +(* 1. Allocate register for the entry part. + For each variable, simply use the most frequent register from predecessors. + Since a block may not be the begining of the loop back edge, for loop back edge predecessors, force them to use the same register. +*) +(* TODO: Optimize choosing strategy *) +let alloc_entry (bl : VBlockLabel.t) = + let binfo = Spill.get_spillinfo bl in + let rinfo = get_allocinfo bl in + let block = VProg.get_block !vprog bl in + let freq : (int SlotMap.t) SlotMap.t ref = ref SlotMap.empty in + + (* The block maybe a begining of a loop back edge and some of the registers may be destined *) + let reg_map = ref rinfo.entry_map in + + (* 1. Count frequencies *) + List.iter + (fun pred -> + match pred with + | VBlock.NormalEdge pred_bl -> + let pred_info = get_allocinfo pred_bl in + SlotMap.iter pred_info.exit_map (fun var reg -> + freq := incr_freq_by_one !freq var reg; + ()) + | VBlock.LoopBackEdge _ -> () + ) block.preds; + + (* 2. 
Allocate rest variable *) + (* Gather unallocated variables *) + let unalloc = ref SlotSet.empty in + (* Registers already bound at block entry. A preferred register is only honoured while it is still free: two variables may prefer the same register (each held it in a different predecessor), and binding both to it would break the one-variable-per-register invariant that update_normal_pred relies on. *) + let taken = ref (SlotMap.fold !reg_map SlotSet.empty (fun _ reg used -> SlotSet.add used reg)) in + SlotSet.iter binfo.entryW + (fun var -> + if Option.is_none @@ SlotMap.find_opt !reg_map var then + match find_max_freq (SlotMap.find_opt !freq var) with + | Some reg when not (SlotSet.mem !taken reg) -> + reg_map := SlotMap.add !reg_map var reg; + taken := SlotSet.add !taken reg + | Some _ | None -> unalloc := SlotSet.add !unalloc var + ); + (* Allocate them *) + let reg_used = SlotMap.fold !reg_map SlotSet.empty (fun _ reg used -> SlotSet.add used reg) in + let reg_left = SlotSet.diff available_regs reg_used in + let _ = SlotSet.fold !unalloc reg_left + (fun var reg_left -> + let reg = SlotSet.choose reg_left in + reg_map := SlotMap.add !reg_map var reg; + SlotSet.remove reg_left reg + ) in + + (* 3. Destine registers for loop variables *) + List.iter + (fun pred -> + match pred with + | VBlock.LoopBackEdge pred_bl -> + let pred_alloc = get_allocinfo pred_bl in + if pred_alloc.entry_map <> SlotMap.empty && pred_bl.name <> bl.name then failwith (Printf.sprintf "reg_alloc : Multiple Loopback edge %s" (bl.name)); + if pred_bl.name <> bl.name then( + let pred_info = Spill.get_spillinfo pred_bl in + let pred_map = SlotSet.fold pred_info.entryW SlotMap.empty (fun var pred_map -> + (match SlotMap.find_opt !reg_map var with + | Some reg -> SlotMap.add pred_map var reg + | None -> pred_map) + ) in + update_allocinfo pred_bl { pred_alloc with entry_map = pred_map }) + | VBlock.NormalEdge pred_bl -> () + ) block.preds; + update_allocinfo bl { rinfo with entry_map = !reg_map } + +(* 2. 
Allocate registers for the body of the block *) +let alloc_body (bl : VBlockLabel.t) = + let rinfo = get_allocinfo bl in + let block = VProg.get_block !vprog bl in + let reg_map = ref rinfo.entry_map in + let reg_used = SlotMap.fold !reg_map SlotSet.empty (fun _ reg used -> SlotSet.add used reg) in + let reg_left_init = SlotSet.diff available_regs reg_used in + let reg_left = ref reg_left_init in + reg_map := SlotMap.add !reg_map Slot.Unit Slot.Unit; + reg_map := SlotMap.add !reg_map (Slot.Reg Zero) (Slot.Reg Zero); + + (* Reproduce spill information by Reload and Spill instructions *) + (* Allocate for new produced variables *) + Vec.iteri block.body (fun i inst -> + match inst with + | Reload { origin } (*| FReload { origin }*) -> + let reg = SlotSet.choose !reg_left in + reg_map := SlotMap.add !reg_map origin reg; + reg_left := SlotSet.remove !reg_left reg; + Vec.set block.body i (Reload { target = reg ; origin }) + | Spill { origin } (*| FSpill { origin }*) -> + let reg = SlotMap.find_exn !reg_map origin in + reg_map := SlotMap.remove !reg_map origin; + reg_left := SlotSet.add !reg_left reg; + Vec.set block.body i (Spill { target = reg ; origin }) + | _ -> + let dests = Inst.get_dests inst in + List.iter (fun var -> match SlotMap.find_opt !reg_map var with + | Some _ -> () + | None -> + let reg = SlotSet.choose !reg_left in + reg_map := SlotMap.add !reg_map var reg; + reg_left := SlotSet.remove !reg_left reg; + ) dests; + Vec.set block.body i (Inst.inst_convert inst (fun slot -> SlotMap.find_exn !reg_map slot)) + ); + + update_allocinfo bl { rinfo with exit_map = !reg_map }; + () + +(* Record new successors for blocks with multiple successors *) +let conv_map : VBlockLabel.t VBlockMap.t VBlockMap.t ref = ref VBlockMap.empty + +(* Helper function : Solve the conflict with a normal edge by generate a series of move instructions *) +(* Consider a variable in reg1 in predecessor and reg2 in current block, we need a mv reg1 reg2 for it. 
+ However, if reg2 is also used in predecessor, we need to make sure the mv instructions are executed in the right order. + Build a graph for the moves, add a edge from rs to rd for each move. + Since each register can only be used once in predecessor and current block, the graph is a collection of chains and cycles. + For each chain, we can simply execute the moves from tail to head. + For each cycle, we need a temporary register to break the cycle. + Here we use the spill register as the temporary register, since it is guaranteed to be free at this point. +*) +let update_normal_pred (bl : VBlockLabel.t) (pred : VBlockLabel.t) = + let binfo = get_allocinfo bl in + let pinfo = get_allocinfo pred in + let addInst = Vec.empty () in + let values map = SlotMap.fold map [] (fun _ v acc -> v :: acc) in + + (* Series of data structures to represent the graph *) + (* A disjoint set to represent the connected components *) + let move_graph = DisjointSet.create 16 in + (* A set to represent the connected components *) + let is_circle = ref SlotSet.empty in + (* A map to represent the tail of each connected component *) + (* Will also record all connected nodes since a disjoint set can only represent connectivity *) + let tail = ref SlotMap.empty in + (* A map to represent the predecessor of each node *) + let pred_map = ref SlotMap.empty in + (* A function to get the representative of a connected component *) + let get_fa = fun var -> fst !(DisjointSet.find move_graph var) in + + (* Initialize the data structures *) + List.iter (fun var -> + DisjointSet.make_set move_graph var; + tail := SlotMap.add !tail var var; + ) (List.append (values binfo.entry_map) (values pinfo.exit_map)); + + (* Build the graph *) + SlotMap.iter pinfo.exit_map (fun var preg -> + match SlotMap.find_opt binfo.entry_map var with + | Some reg -> if preg <> reg then + (if DisjointSet.are_connected move_graph preg reg then ( + (* When the edge form a circle, no more edges will be added *) + is_circle := 
SlotSet.add !is_circle (get_fa preg); + pred_map := SlotMap.add !pred_map reg preg + ) else ( + let new_tail = SlotMap.find_exn !tail (get_fa reg) in + tail := SlotMap.remove !tail (get_fa reg); + tail := SlotMap.remove !tail (get_fa preg); + DisjointSet.union move_graph preg reg; + tail := SlotMap.add !tail (get_fa preg) new_tail; + pred_map := SlotMap.add !pred_map reg preg; + ) + ) + | None -> () + ); + + (* Generate the move instructions *) + SlotMap.iter !tail (fun fa var_tail -> + if SlotSet.mem !is_circle fa then ( + (* Cycle: save fa in the temporary, then rotate. Each register receives the value of its predecessor in the cycle; the register whose predecessor is fa is filled from the temporary because fa itself has been overwritten by then. *) + Vec.push addInst (Inst.Mv { rd = Slot.Reg Reg.spill_reg ; rs = fa }); + let rec loop var = + let next = SlotMap.find_exn !pred_map var in + if next <> fa then ( + Vec.push addInst (Inst.Mv { rd = var ; rs = next }); + loop next + ) else + Vec.push addInst (Inst.Mv { rd = var ; rs = Slot.Reg Reg.spill_reg }) + in + loop fa + ) else ( + (* Chain: walk from the tail back to the head so every source is read before it is overwritten *) + let rec loop var = + match SlotMap.find_opt !pred_map var with + | Some next -> + Vec.push addInst (Inst.Mv { rd = var ; rs = next }); + loop next + | None -> () in + loop var_tail + ) + ); + addInst + +(* 3. Solve the problem of conflicts with predecessors and convert terminator *) +(* For predecessors with single successor, insert move instructions at the end of it. This may change the registers in terminator. (case 1) + For others, if the current block has only one predecessor, insert move instructions at the beginning of current block. (case 2) + Otherwise, we have no choice but to create a new block to hold the move instructions. 
(case 3) +*) +let solve_edge (bl : VBlockLabel.t) = + let block = VProg.get_block !vprog bl in + let preds = List.map (fun pred -> + match pred with + | VBlock.NormalEdge pred_bl -> + let pred_block = VProg.get_block !vprog pred_bl in + let pred_info = get_allocinfo pred_bl in + let reg_map = pred_info.exit_map in + let addInst = update_normal_pred bl pred_bl in + let convert_term term = Term.term_map_reg term (fun slot -> SlotMap.find_exn reg_map slot) in + if List.length @@ VBlock.get_successors pred_block = 1 then ( + (* Case 1 *) + Vec.append pred_block.body addInst; + (* Calculate effect of the moves *) + (* move_map records, for each source register, the register its value ends up in; spill_match remembers which register is currently parked in the temporary *) + let move_map = ref SlotMap.empty in + let spill_match = ref None in + Vec.iter (fun inst -> + match inst with + | Inst.Mv { rd ; rs } -> ( + if rd = Slot.Reg Reg.spill_reg then ( + match !spill_match with + | Some _ -> failwith "reg_alloc: multiple use of spill_reg in edge" + | None -> spill_match := Some rs + ) else ( + if rs = Slot.Reg Reg.spill_reg then ( + match !spill_match with + | Some reg -> + move_map := SlotMap.add !move_map reg rd; + (* The temporary is free again once the saved value has been written back, so a later cycle may reuse it *) + spill_match := None + | None -> failwith "reg_alloc: use of spill_reg without match in edge" + ) else ( + move_map := SlotMap.add !move_map rs rd + ) + ); + ) + | _ -> failwith "reg_alloc: unexpected instruction in edge solving" + ) addInst; + let convert_term_moved term = Term.term_map_reg term (fun slot -> Option.value (SlotMap.find_opt !move_map slot) ~default:slot) in + vprog := VProg.update_block !vprog pred_bl { pred_block with term = pred_block.term |> convert_term |> convert_term_moved }; + pred + ) else ( + if List.length block.preds = 1 then ( + (* Case 2 *) + Vec.append addInst block.body; + Vec.clear block.body; + Vec.append block.body addInst; + conv_map := VBlockMap.add !conv_map pred_bl @@ VBlockMap.add (VBlockMap.find_default !conv_map pred_bl VBlockMap.empty) bl bl; + pred + ) else ( + (* Case 3 *) + let new_block_label = VBlockLabel.fresh "edge" in + let term = Term.J bl in + let preds = [VBlock.NormalEdge pred_bl] in + let 
new_block : VBlock.t = { body = addInst; term; preds } in + conv_map := VBlockMap.add !conv_map pred_bl @@ VBlockMap.add (VBlockMap.find_default !conv_map pred_bl VBlockMap.empty) bl new_block_label; + vprog := VProg.update_block !vprog new_block_label new_block; + VBlock.NormalEdge new_block_label + ) + ) + | VBlock.LoopBackEdge _ -> pred + ) block.preds in + vprog := VProg.update_block !vprog bl { block with preds }; + () + +(* Main function: used to allocate register for a block and solve its conflict with predecessors *) +let alloc_block (bl : VBlockLabel.t) = + (* 1. Allocate register for the entry part.*) + alloc_entry bl; + + (* 2. Allocate registers for the body of the block *) + alloc_body bl; + + (* 3. Solve the problem of conflicts with predecessors and convert terminator *) + solve_edge bl + +(* 4. Handle the problem of unprocessed successors -- + Since the algorithm runs in a rpo order, terminators of blocks with multiple successors will not be deal at once +*) +let update_multisucc_term (bl : VBlockLabel.t) = + let block = VProg.get_block !vprog bl in + let binfo = get_allocinfo bl in + match VBlockMap.find_opt !conv_map bl with + | None -> () + | Some block_map -> + let convert_term_label term = Term.term_map_label term (fun label -> VBlockMap.find_exn block_map label) in + let reg_map = binfo.exit_map in + let convert_term term = Term.term_map_reg term (fun slot -> SlotMap.find_exn reg_map slot) in + let new_term = block.term |> convert_term_label |> convert_term in + vprog := VProg.update_block !vprog bl { block with term = new_term }; + () + +(* Main function: used to handle the entire program *) +(* TODO : currently no FP support *) +let alloc_func (f_label : VFuncLabel.t) (func : VFunc.t) = + let rpo_func = RPO.get_func_rpo f_label !rpo in + List.iter alloc_block rpo_func; + + List.iter update_multisucc_term rpo_func; + () + +let reg_alloc (vprog_arg: VProg.t) = + rpo := RPO.calculate_rpo vprog_arg; + (* A. Spill pass. 
Decide which registers to spill *) + Spill.spill_regs vprog_arg !rpo; + + vprog := vprog_arg; let out = Printf.sprintf "%s-spilled.vasm" !Driver_config.Linkcore_Opt.output_file in - Basic_io.write out (VProg.to_string vprog); - vprog + Basic_io.write out (VProg.to_string !vprog); + + (* B. Allocate registers *) + VFuncMap.iter !vprog.funcs alloc_func; + let out = Printf.sprintf "%s-allocated.vasm" !Driver_config.Linkcore_Opt.output_file in + Basic_io.write out (VProg.to_string !vprog); + + (* TODO : Optimization *) + !vprog diff --git a/src/riscv_reg_util.ml b/src/riscv_reg_util.ml index 0d22d06..995f8b5 100644 --- a/src/riscv_reg_util.ml +++ b/src/riscv_reg_util.ml @@ -75,6 +75,7 @@ module Liveness = struct ; liveIn : SlotSet.t ; liveOut : SlotSet.t ; exitNextUse : int SlotMap.t + ; entryNextUse : int SlotMap.t } (** @@ -90,6 +91,7 @@ module Liveness = struct ; liveIn = SlotSet.empty ; liveOut = SlotSet.empty ; exitNextUse = SlotMap.empty + ; entryNextUse = SlotMap.empty } ;; @@ -151,7 +153,7 @@ module Liveness = struct (* liveOut = union of successors' liveIn *) b_liveOut := SlotSet.union !b_liveOut succ_info.liveIn; (* exitNextUse = accumulate based on successors' entryNextUse *) - SlotMap.iter succ_info.exitNextUse (fun slot dist_in_succ -> + SlotMap.iter succ_info.entryNextUse (fun slot dist_in_succ -> let new_dist = sat_add dist_in_succ 1 in match SlotMap.find_opt !b_exitNextUse slot with | Some old_d -> @@ -221,6 +223,7 @@ module Liveness = struct ; liveIn = !b_liveIn ; liveOut = !b_liveOut ; exitNextUse = !b_exitNextUse + ; entryNextUse = !b_entryNextUse } in (* diff --git a/src/riscv_virtasm.ml b/src/riscv_virtasm.ml index cb20288..25f1cc4 100644 --- a/src/riscv_virtasm.ml +++ b/src/riscv_virtasm.ml @@ -286,6 +286,96 @@ module Inst = struct (* Stack Allocation Directive *) | Alloca of alloca + (* TODO : Refactor *) + let inst_convert (inst : t) (f : Slot.t -> Slot.t) : t = + match inst with + | Add { rd; rs1; rs2 } -> Add { rd = f rd; rs1 = f rs1; rs2 = f 
rs2 } + | Addw { rd; rs1; rs2 } -> Addw { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Sub { rd; rs1; rs2 } -> Sub { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Subw { rd; rs1; rs2 } -> Subw { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Mul { rd; rs1; rs2 } -> Mul { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Mulw { rd; rs1; rs2 } -> Mulw { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Div { rd; rs1; rs2 } -> Div { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Divw { rd; rs1; rs2 } -> Divw { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Divu { rd; rs1; rs2 } -> Divu { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Divuw { rd; rs1; rs2 } -> Divuw { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Rem { rd; rs1; rs2 } -> Rem { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Remw { rd; rs1; rs2 } -> Remw { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Remu { rd; rs1; rs2 } -> Remu { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Remuw { rd; rs1; rs2 } -> Remuw { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Sextw { rd; rs } -> Sextw { rd = f rd; rs = f rs } + | Zextw { rd; rs } -> Zextw { rd = f rd; rs = f rs } + | Addi { rd; rs1; imm } -> Addi { rd = f rd; rs1 = f rs1; imm } + | Addiw { rd; rs1; imm } -> Addiw { rd = f rd; rs1 = f rs1; imm } + | And { rd; rs1; rs2 } -> And { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Or { rd; rs1; rs2 } -> Or { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Xor { rd; rs1; rs2 } -> Xor { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Andi { rd; rs1; imm } -> Andi { rd = f rd; rs1 = f rs1; imm } + | Ori { rd; rs1; imm } -> Ori { rd = f rd; rs1 = f rs1; imm } + | Xori { rd; rs1; imm } -> Xori { rd = f rd; rs1 = f rs1; imm } + | Slt { rd; rs1; rs2 } -> Slt { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Sltw { rd; rs1; rs2 } -> Sltw { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Sltu { rd; rs1; rs2 } -> Sltu { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Sltuw { rd; rs1; rs2 } -> Sltuw { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Slti { rd; rs1; imm } -> Slti { rd = f rd; rs1 = f rs1; imm } + | Sltiw { rd; 
rs1; imm } -> Sltiw { rd = f rd; rs1 = f rs1; imm } + | Sll { rd; rs1; rs2 } -> Sll { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Sllw { rd; rs1; rs2 } -> Sllw { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Srl { rd; rs1; rs2 } -> Srl { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Srlw { rd; rs1; rs2 } -> Srlw { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Sra { rd; rs1; rs2 } -> Sra { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Sraw { rd; rs1; rs2 } -> Sraw { rd = f rd; rs1 = f rs1; rs2 = f rs2 } + | Slli { rd; rs1; imm } -> Slli { rd = f rd; rs1 = f rs1; imm } + | Slliw { rd; rs1; imm } -> Slliw { rd = f rd; rs1 = f rs1; imm } + | Srli { rd; rs1; imm } -> Srli { rd = f rd; rs1 = f rs1; imm } + | Srliw { rd; rs1; imm } -> Srliw { rd = f rd; rs1 = f rs1; imm } + | Srai { rd; rs1; imm } -> Srai { rd = f rd; rs1 = f rs1; imm } + | Sraiw { rd; rs1; imm } -> Sraiw { rd = f rd; rs1 = f rs1; imm } + | Lb { rd; base; offset } -> Lb { rd = f rd; base = f base; offset } + | Lbu { rd; base; offset } -> Lbu { rd = f rd; base = f base; offset } + | Lh { rd; base; offset } -> Lh { rd = f rd; base = f base; offset } + | Lhu { rd; base; offset } -> Lhu { rd = f rd; base = f base; offset } + | Lw { rd; base; offset } -> Lw { rd = f rd; base = f base; offset } + | Ld { rd; base; offset } -> Ld { rd = f rd; base = f base; offset } + | Sb { rd; base; offset } -> Sb { rd = f rd; base = f base; offset } + | Sh { rd; base; offset } -> Sh { rd = f rd; base = f base; offset } + | Sw { rd; base; offset } -> Sw { rd = f rd; base = f base; offset } + | Sd { rd; base; offset } -> Sd { rd = f rd; base = f base; offset } + | FaddD { frd; frs1; frs2 } -> FaddD { frd = f frd; frs1 = f frs1; frs2 = f frs2 } + | FsubD { frd; frs1; frs2 } -> FsubD { frd = f frd; frs1 = f frs1; frs2 = f frs2 } + | FmulD { frd; frs1; frs2 } -> FmulD { frd = f frd; frs1 = f frs1; frs2 = f frs2 } + | FdivD { frd; frs1; frs2 } -> FdivD { frd = f frd; frs1 = f frs1; frs2 = f frs2 } + | FmaddD { frd; frs1; frs2; frs3 } -> FmaddD { frd 
= f frd; frs1 = f frs1; frs2 = f frs2; frs3 = f frs3 } + | FmsubD { frd; frs1; frs2; frs3 } -> FmsubD { frd = f frd; frs1 = f frs1; frs2 = f frs2; frs3 = f frs3 } + | FnmaddD { frd; frs1; frs2; frs3 } -> FnmaddD { frd = f frd; frs1 = f frs1; frs2 = f frs2; frs3 = f frs3 } + | FnmsubD { frd; frs1; frs2; frs3 } -> FnmsubD { frd = f frd; frs1 = f frs1; frs2 = f frs2; frs3 = f frs3 } + | FeqD { rd; frs1; frs2 } -> FeqD { rd = f rd; frs1 = f frs1; frs2 = f frs2 } + | FltD { rd; frs1; frs2 } -> FltD { rd = f rd; frs1 = f frs1; frs2 = f frs2 } + | FleD { rd; frs1; frs2 } -> FleD { rd = f rd; frs1 = f frs1; frs2 = f frs2 } + | FcvtDW { frd; rs } -> FcvtDW { frd = f frd; rs = f rs } + | FcvtDL { frd; rs } -> FcvtDL { frd = f frd; rs = f rs } + | FcvtLD { rd; frs } -> FcvtLD { rd = f rd; frs = f frs } + | FcvtWDRtz { rd; frs } -> FcvtWDRtz { rd = f rd; frs = f frs } + | FsqrtD { frd; frs } -> FsqrtD { frd = f frd; frs = f frs } + | FabsD { frd; frs } -> FabsD { frd = f frd; frs = f frs } + | FnegD { frd; frs } -> FnegD { frd = f frd; frs = f frs } + | FmvD { frd; frs } -> FmvD { frd = f frd; frs = f frs } + | FmvDX { frd; rs } -> FmvDX { frd = f frd; rs = f rs } + | FmvDXZero { frd } -> FmvDXZero { frd = f frd } + | Fld { frd; base; offset } -> Fld { frd = f frd; base = f base; offset } + | Fsd { frd; base; offset } -> Fsd { frd = f frd; base = f base; offset } + | La { rd; label } -> La { rd = f rd; label } + | Li { rd; imm } -> Li { rd = f rd; imm } + | Mv { rd; rs } -> Mv { rd = f rd; rs = f rs } + | Call { rd; fn; args; fargs } -> Call { rd = f rd; fn; args = List.map f args; fargs = List.map f fargs } + | CallIndirect { rd; fn; args; fargs } -> CallIndirect { rd = f rd; fn = f fn; args = List.map f args; fargs = List.map f fargs } + | Spill { target; origin } -> Spill { target = f target; origin } + | Reload { target; origin } -> Reload { target = f target; origin } + | FSpill { target; origin } -> FSpill { target = f target; origin } + | FReload { target; origin } -> 
FReload { target = f target; origin } + | Alloca { rd; size } -> Alloca { rd = f rd; size } + ;; + let inst_map (inst : t) (rd : Slot.t -> Slot.t list) (rs : Slot.t -> Slot.t list) = match inst with | Add r_slot | Addw r_slot @@ -328,8 +418,8 @@ module Inst = struct | FsqrtD assign_fslot | FabsD assign_fslot | FnegD assign_fslot | FmvD assign_fslot -> rd assign_fslot.frd @ rs assign_fslot.frs | Fld mem_fslot | Fsd mem_fslot -> rd mem_fslot.frd @ rs mem_fslot.base - | La assign_label -> [] - | Li assign_int64 -> [] + | La assign_label -> rd assign_label.rd + | Li assign_int64 -> rd assign_int64.rd | Mv assign_slot | Sextw assign_slot | Zextw assign_slot -> rd assign_slot.rd @ rs assign_slot.rs | FmvDX assign_direct -> rd assign_direct.frd @ rs assign_direct.rs @@ -356,13 +446,13 @@ module Inst = struct let adjust_rec_alloc_I (inst : t) (pre_K : int) : int = match inst with - | Call _ | CallIndirect _ -> pre_K - List.length Reg.caller_saved_regs + (* | Call _ | CallIndirect _ -> pre_K - List.length Reg.caller_saved_regs *) | _ -> pre_K ;; let adjust_rec_alloc_F (inst : t) (pre_K : int) : int = match inst with - | Call _ | CallIndirect _ -> pre_K - List.length FReg.caller_saved_fregs + (* | Call _ | CallIndirect _ -> pre_K - List.length FReg.caller_saved_fregs *) | _ -> pre_K let to_string x = @@ -517,6 +607,36 @@ module Term = struct | TailCall of call_data | TailCallIndirect of call_indirect | Ret of Slot.t (* Unit for no return*) + ;; + + let term_map_reg (term : t) (f : Slot.t -> Slot.t) : t = + match term with + | Beq { rs1; rs2; ifso; ifnot } -> Beq { rs1 = f rs1; rs2 = f rs2; ifso; ifnot } + | Bne { rs1; rs2; ifso; ifnot } -> Bne { rs1 = f rs1; rs2 = f rs2; ifso; ifnot } + | Blt { rs1; rs2; ifso; ifnot } -> Blt { rs1 = f rs1; rs2 = f rs2; ifso; ifnot } + | Bge { rs1; rs2; ifso; ifnot } -> Bge { rs1 = f rs1; rs2 = f rs2; ifso; ifnot } + | Bltu { rs1; rs2; ifso; ifnot } -> Bltu { rs1 = f rs1; rs2 = f rs2; ifso; ifnot } + | Bgeu { rs1; rs2; ifso; ifnot } -> Bgeu { 
rs1 = f rs1; rs2 = f rs2; ifso; ifnot } + | Jalr { rd; rs1; offset } -> Jalr { rd = f rd; rs1 = f rs1; offset } + | TailCall { rd; fn; args; fargs } -> TailCall { rd = f rd; fn; args = List.map f args; fargs = List.map f fargs } + | TailCallIndirect { rd; fn; args; fargs } -> TailCallIndirect { rd = f rd; fn = f fn; args = List.map f args; fargs = List.map f fargs } + | Ret var -> Ret (f var) + | x -> x + ;; + + let term_map_label (term : t) (f : VBlockLabel.t -> VBlockLabel.t) : t = + match term with + | Beq { rs1; rs2; ifso; ifnot } -> Beq { rs1; rs2; ifso = f ifso; ifnot = f ifnot } + | Bne { rs1; rs2; ifso; ifnot } -> Bne { rs1; rs2; ifso = f ifso; ifnot = f ifnot } + | Blt { rs1; rs2; ifso; ifnot } -> Blt { rs1; rs2; ifso = f ifso; ifnot = f ifnot } + | Bge { rs1; rs2; ifso; ifnot } -> Bge { rs1; rs2; ifso = f ifso; ifnot = f ifnot } + | Bltu { rs1; rs2; ifso; ifnot } -> Bltu { rs1; rs2; ifso = f ifso; ifnot = f ifnot } + | Bgeu { rs1; rs2; ifso; ifnot } -> Bgeu { rs1; rs2; ifso = f ifso; ifnot = f ifnot } + | J label -> J (f label) + | Jal label -> Jal (f label) + | x -> x + ;; + let get_srcs (term : t) : Slot.t list = match term with @@ -647,6 +767,10 @@ module VProg = struct | Some x -> x ;; + let update_block (vprog : t) (bl : VBlockLabel.t) (b : VBlock.t) : t = + { vprog with blocks = VBlockMap.add vprog.blocks bl b } + ;; + let get_func (vprog : t) (fn : VFuncLabel.t) : VFunc.t = match VFuncMap.find_opt vprog.funcs fn with | None -> failwith "get_func: function not found"