diff --git a/hw/rtl/VX_gpu_pkg.sv b/hw/rtl/VX_gpu_pkg.sv index fc334c2f9..762a59121 100644 --- a/hw/rtl/VX_gpu_pkg.sv +++ b/hw/rtl/VX_gpu_pkg.sv @@ -467,7 +467,9 @@ package VX_gpu_pkg; `ifdef EXT_TCU_ENABLE - localparam INST_TCU_WMMA = 4'h0; + localparam INST_TCU_WMMA = 4'h0; + localparam INST_TCU_WMMA_SP = 4'h1; + localparam INST_TCU_META_STORE = 4'h2; localparam INST_TCU_BITS = 4; `endif @@ -569,9 +571,10 @@ package VX_gpu_pkg; `ifdef EXT_TCU_ENABLE typedef struct packed { - logic [(INST_ARGS_BITS-16)-1:0] __padding; + logic [(INST_ARGS_BITS-20)-1:0] __padding; logic [3:0] fmt_d; logic [3:0] fmt_s; + logic [3:0] step_k; logic [3:0] step_n; logic [3:0] step_m; } tcu_args_t; diff --git a/hw/rtl/core/VX_decode.sv b/hw/rtl/core/VX_decode.sv index 30a8ca738..7333ed471 100644 --- a/hw/rtl/core/VX_decode.sv +++ b/hw/rtl/core/VX_decode.sv @@ -555,21 +555,29 @@ module VX_decode import VX_gpu_pkg::*; #( `endif `ifdef EXT_TCU_ENABLE 7'h02: begin - case (funct3) - 3'h0: begin // WMMA_SYNC - ex_type = EX_TCU; - op_type = INST_OP_BITS'(INST_TCU_WMMA); - op_args.tcu.fmt_s = rs1[3:0]; - op_args.tcu.fmt_d = rd[3:0]; - op_args.tcu.step_m = '0; - op_args.tcu.step_n = '0; - `USED_FREG (rd); - `USED_FREG (rs1); - `USED_FREG (rs2); - `USED_FREG (rs3); - end - default:; - endcase + if (funct3 == 3'h0 || funct3 == 3'h1) begin + ex_type = EX_TCU; + op_type = funct3[0] ? 
INST_OP_BITS'(INST_TCU_WMMA_SP) + : INST_OP_BITS'(INST_TCU_WMMA); + op_args.tcu.fmt_s = rs1[3:0]; + op_args.tcu.fmt_d = rd[3:0]; + op_args.tcu.step_m = '0; + op_args.tcu.step_n = '0; + op_args.tcu.step_k = '0; + `USED_FREG (rd); + `USED_FREG (rs1); + `USED_FREG (rs2); + `USED_FREG (rs3); + end else if (funct3 == 3'h2) begin + ex_type = EX_TCU; + op_type = INST_OP_BITS'(INST_TCU_META_STORE); + op_args.tcu.fmt_d = rd[3:0]; // col_idx + op_args.tcu.fmt_s = '0; + op_args.tcu.step_m = '0; + op_args.tcu.step_n = '0; + op_args.tcu.step_k = '0; + `USED_FREG (rs1); // source float register + end end `endif default:; diff --git a/hw/rtl/core/VX_uop_sequencer.sv b/hw/rtl/core/VX_uop_sequencer.sv index e01268e00..44bb3f50d 100644 --- a/hw/rtl/core/VX_uop_sequencer.sv +++ b/hw/rtl/core/VX_uop_sequencer.sv @@ -45,7 +45,9 @@ module VX_uop_sequencer import `ifdef EXT_TCU_ENABLE - assign is_base_uop_input = (input_if.data.ex_type == EX_TCU && input_if.data.op_type == INST_TCU_WMMA); + assign is_base_uop_input = (input_if.data.ex_type == EX_TCU + && (input_if.data.op_type == INST_TCU_WMMA + || input_if.data.op_type == INST_TCU_WMMA_SP)); VX_tcu_uops tcu_uops ( .clk (clk), diff --git a/hw/rtl/tcu/VX_tcu_core.sv b/hw/rtl/tcu/VX_tcu_core.sv index fc5cc7e87..840067420 100644 --- a/hw/rtl/tcu/VX_tcu_core.sv +++ b/hw/rtl/tcu/VX_tcu_core.sv @@ -56,17 +56,41 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( localparam PIPE_LATENCY = FEDP_LATENCY + 1; localparam MDATA_QUEUE_DEPTH = 1 << $clog2(PIPE_LATENCY); - localparam LG_A_BS = $clog2(TCU_A_BLOCK_SIZE); - localparam LG_B_BS = $clog2(TCU_B_BLOCK_SIZE); - localparam OFF_W = $clog2(TCU_BLOCK_CAP); + localparam LG_A_BS = $clog2(TCU_A_BLOCK_SIZE); + localparam LG_B_BS = $clog2(TCU_B_BLOCK_SIZE); + localparam LG_B_BS_SP = $clog2(TCU_B_BLOCK_SIZE_SP); + localparam OFF_W = $clog2(TCU_BLOCK_CAP); + + wire is_sparse = (execute_if.data.op_type == INST_TCU_WMMA_SP); + wire is_meta_store = (execute_if.data.op_type == INST_TCU_META_STORE); 
wire [3:0] step_m = execute_if.data.op_args.tcu.step_m; wire [3:0] step_n = execute_if.data.op_args.tcu.step_n; + wire [3:0] step_k = execute_if.data.op_args.tcu.step_k; wire [3:0] fmt_s = execute_if.data.op_args.tcu.fmt_s; wire [3:0] fmt_d = execute_if.data.op_args.tcu.fmt_d; - `UNUSED_VAR ({step_m, step_n, fmt_s, fmt_d, execute_if.data}); + wire [`LOG2UP(`NUM_WARPS)-1:0] wid = execute_if.data.header.wid; + + // meta_store: extract per-row write data from rs1_data lanes + localparam PER_WARP_DEPTH = TCU_M_STEPS * (TCU_K_STEPS / 2); + wire meta_wr_en = execute_fire && is_meta_store; + wire [PER_WARP_DEPTH-1:0][31:0] meta_wr_data; + for (genvar r = 0; r < PER_WARP_DEPTH; ++r) begin : g_meta_wr + assign meta_wr_data[r] = 32'(execute_if.data.rs1_data[r]); + end + + // meta_store: force rd=0 in mdata_queue header (x0 write is harmless) + tcu_header_t mdata_queue_in; + always_comb begin + mdata_queue_in = execute_if.data.header; + if (is_meta_store) begin + mdata_queue_in.rd = '0; + end + end + + `UNUSED_VAR ({step_m, step_n, step_k, fmt_s, fmt_d, execute_if.data}); wire mdata_queue_full; @@ -103,7 +127,7 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( .reset (reset), .push (execute_fire), .pop (result_fire), - .data_in(execute_if.data.header), + .data_in(mdata_queue_in), .data_out(result_if.data.header), `UNUSED_PIN(empty), `UNUSED_PIN(alm_empty), @@ -113,18 +137,68 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( ); wire [OFF_W-1:0] a_off = (OFF_W'(step_m) & OFF_W'(TCU_A_SUB_BLOCKS-1)) << LG_A_BS; - wire [OFF_W-1:0] b_off = (OFF_W'(step_n) & OFF_W'(TCU_B_SUB_BLOCKS-1)) << LG_B_BS; + wire [OFF_W-1:0] b_off = is_sparse + ? 
(OFF_W'(step_n) & OFF_W'(TCU_B_SUB_BLOCKS_SP-1)) << LG_B_BS_SP + : (OFF_W'(step_n) & OFF_W'(TCU_B_SUB_BLOCKS-1)) << LG_B_BS; wire [TCU_TC_M-1:0][TCU_TC_N-1:0][31:0] d_val; + // 2:4 sparsity metadata +`ifndef TCU_ITYPE_BITS +`define TCU_ITYPE_BITS 8 +`endif + localparam I_RATIO = 32 / `TCU_ITYPE_BITS; // Elements per 32-bit word + localparam META_BLOCK_WIDTH = TCU_NT * 2 * I_RATIO; + localparam META_ROW_WIDTH = TCU_TC_K * 2 * I_RATIO; + localparam ELT_W = 32 / I_RATIO; // bits per element (8 for int8) + wire [META_BLOCK_WIDTH-1:0] vld_meta_block; + + VX_tcu_meta #( + .INSTANCE_ID (INSTANCE_ID), + .META_BLOCK_WIDTH(META_BLOCK_WIDTH), + .PER_WARP_DEPTH (PER_WARP_DEPTH) + ) tcu_meta ( + .clk (clk), + .reset (reset), + .raddr_wid (wid), + .step_m (step_m), + .step_k (step_k), + .vld_meta_block(vld_meta_block), + .wr_en (meta_wr_en), + .wr_wid (wid), + .wr_col_idx (fmt_d), + .wr_data (meta_wr_data) + ); + for (genvar i = 0; i < TCU_TC_M; ++i) begin : g_i for (genvar j = 0; j < TCU_TC_N; ++j) begin : g_j - wire [TCU_TC_K-1:0][31:0] a_row, b_col; + wire [TCU_TC_K-1:0][31:0] a_row, b_col, b_col_dense, b_col_sparse, b_col_1, b_col_2; for (genvar k_idx = 0; k_idx < TCU_TC_K; ++k_idx) begin : g_slice_assign - assign a_row[k_idx] = 32'(execute_if.data.rs1_data[a_off + i * TCU_TC_K + k_idx]); - assign b_col[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K + k_idx]); + assign a_row[k_idx] = 32'(execute_if.data.rs1_data[a_off + i * TCU_TC_K + k_idx]); + assign b_col_dense[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K + k_idx]); + assign b_col_1[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2]); + assign b_col_2[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2 + 1]); end wire [31:0] c_val = 32'(execute_if.data.rs3_data[i * TCU_TC_N + j]); + /* verilator lint_off UNUSEDSIGNAL */ + wire [TCU_MAX_INPUTS-1:0] vld_mask = '1; // TODO: should connect to input source + /* verilator lint_on UNUSEDSIGNAL */ + wire 
[META_ROW_WIDTH-1:0] vld_meta_row = vld_meta_block[META_ROW_WIDTH*i +: META_ROW_WIDTH]; + + VX_tcu_sel #( + .INSTANCE_ID (INSTANCE_ID), + .META_ROW_WIDTH (META_ROW_WIDTH), + .I_RATIO (I_RATIO), + .ELT_W (ELT_W) + ) tcu_sel ( + .b_col_1 (b_col_1), + .b_col_2 (b_col_2), + .vld_meta_row (vld_meta_row), + .b_col (b_col_sparse) + ); + + // Select dense or sparse B column + assign b_col = is_sparse ? b_col_sparse : b_col_dense; wire [3:0] fmt_s_r, fmt_d_r; wire [TCU_TC_K-1:0][31:0] a_row_r, b_col_r; diff --git a/hw/rtl/tcu/VX_tcu_meta.sv b/hw/rtl/tcu/VX_tcu_meta.sv new file mode 100644 index 000000000..21033636f --- /dev/null +++ b/hw/rtl/tcu/VX_tcu_meta.sv @@ -0,0 +1,100 @@ +// Copyright 2019-2023 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +`include "VX_define.vh" + +/* verilator lint_off UNUSEDSIGNAL */ + +module VX_tcu_meta import VX_gpu_pkg::*, VX_tcu_pkg::*; #( + parameter `STRING INSTANCE_ID = "", + parameter META_BLOCK_WIDTH = 64, + parameter PER_WARP_DEPTH = 4 +) ( + input wire clk, + input wire reset, + + // Read port (from FEDP path) + input wire [`LOG2UP(`NUM_WARPS)-1:0] raddr_wid, + input wire [3:0] step_m, + input wire [3:0] step_k, + output wire [META_BLOCK_WIDTH-1:0] vld_meta_block, + + // Write port (meta_store instruction) + input wire wr_en, + input wire [`LOG2UP(`NUM_WARPS)-1:0] wr_wid, + input wire [3:0] wr_col_idx, + input wire [PER_WARP_DEPTH-1:0][31:0] wr_data +); + `UNUSED_SPARAM (INSTANCE_ID) + + // Local parameters + localparam HALF_K_STEPS = TCU_K_STEPS / 2; + localparam TOTAL_DEPTH = `NUM_WARPS * PER_WARP_DEPTH; + localparam ADDRW = `CLOG2(TOTAL_DEPTH); + localparam ADDRW_PW = `CLOG2(PER_WARP_DEPTH); + localparam NUM_COLS = META_BLOCK_WIDTH / 32; + + // Metadata register array (per-warp partitioned) + reg [META_BLOCK_WIDTH-1:0] meta_mem [0:TOTAL_DEPTH-1]; + + // Read address: bit-concatenation of step_m and step_k (pure wire routing, zero delay) + // Use generate-if to avoid zero-width bit-selects when a dimension has only 1 step + localparam M_STEP_BITS = `CLOG2(TCU_M_STEPS); + localparam K_STEP_BITS = `CLOG2(HALF_K_STEPS); + + wire [ADDRW_PW-1:0] per_warp_raddr; + generate + if (K_STEP_BITS > 0 && M_STEP_BITS > 0) begin : g_addr_mk + assign per_warp_raddr = {step_m[M_STEP_BITS-1:0], step_k[K_STEP_BITS-1:0]}; + end else if (K_STEP_BITS > 0) begin : g_addr_k + assign per_warp_raddr = step_k[K_STEP_BITS-1:0]; + end else if (M_STEP_BITS > 0) begin : g_addr_m + assign per_warp_raddr = step_m[M_STEP_BITS-1:0]; + end else begin : g_addr_zero + assign per_warp_raddr = '0; + end + endgenerate + wire [ADDRW-1:0] read_addr = {raddr_wid, per_warp_raddr}; + + // Combinational read + assign vld_meta_block = meta_mem[read_addr]; + + // Post-reset init counter: fills all warps with 
alternating patterns + reg [ADDRW:0] init_counter; + wire init_active = ~init_counter[ADDRW]; + wire [ADDRW-1:0] init_addr = init_counter[ADDRW-1:0]; + wire [META_BLOCK_WIDTH-1:0] init_data = init_addr[0] ? + {(META_BLOCK_WIDTH/4){4'b1010}} : + {(META_BLOCK_WIDTH/4){4'b0101}}; + + // Write logic: init or runtime meta_store + always_ff @(posedge clk) begin + if (reset) begin + init_counter <= 0; + end else if (init_active) begin + meta_mem[init_addr] <= init_data; + init_counter <= init_counter + 1; + end else if (wr_en) begin + for (int row = 0; row < PER_WARP_DEPTH; row++) begin + for (int col = 0; col < NUM_COLS; col++) begin + if (col == int'(wr_col_idx)) begin + meta_mem[{wr_wid, ADDRW_PW'(row)}][col*32 +: 32] <= wr_data[row]; + end + end + end + end + end + +endmodule + +/* verilator lint_on UNUSEDSIGNAL */ diff --git a/hw/rtl/tcu/VX_tcu_pkg.sv b/hw/rtl/tcu/VX_tcu_pkg.sv index 0bb3fed28..1f678416a 100644 --- a/hw/rtl/tcu/VX_tcu_pkg.sv +++ b/hw/rtl/tcu/VX_tcu_pkg.sv @@ -75,10 +75,14 @@ package VX_tcu_pkg; localparam TCU_A_BLOCK_SIZE = TCU_TC_M * TCU_TC_K; localparam TCU_A_SUB_BLOCKS = TCU_BLOCK_CAP / TCU_A_BLOCK_SIZE; - // B micro-tiling + // B micro-tiling (dense) localparam TCU_B_BLOCK_SIZE = TCU_TC_K * TCU_TC_N; localparam TCU_B_SUB_BLOCKS = TCU_BLOCK_CAP / TCU_B_BLOCK_SIZE; + // B micro-tiling (sparse 2:4) + localparam TCU_B_BLOCK_SIZE_SP = (TCU_TC_K * TCU_TC_N) * 2; + localparam TCU_B_SUB_BLOCKS_SP = TCU_BLOCK_CAP / TCU_B_BLOCK_SIZE_SP; + // Register counts //localparam TCU_NRA = (TCU_TILE_M * TCU_TILE_K) / TCU_NT; localparam TCU_NRB = (TCU_TILE_N * TCU_TILE_K) / TCU_NT; @@ -191,13 +195,17 @@ package VX_tcu_pkg; input op_args_t op_args ); case (INST_TCU_BITS'(op_type)) - INST_TCU_WMMA: begin - `TRACE(level, ("WMMA.")); + INST_TCU_WMMA, + INST_TCU_WMMA_SP: begin + `TRACE(level, (INST_TCU_BITS'(op_type) == INST_TCU_WMMA_SP ? "WMMA_SP." 
: "WMMA.")); trace_fmt(level, op_args.tcu.fmt_s); `TRACE(level, (".")); trace_fmt(level, op_args.tcu.fmt_d); `TRACE(level, (".%0d.%0d", op_args.tcu.step_m, op_args.tcu.step_n)); end + INST_TCU_META_STORE: begin + `TRACE(level, ("META_STORE.col%0d", op_args.tcu.fmt_d)); + end default: `TRACE(level, ("?")) endcase endtask diff --git a/hw/rtl/tcu/VX_tcu_sel.sv b/hw/rtl/tcu/VX_tcu_sel.sv new file mode 100644 index 000000000..b3b6815d3 --- /dev/null +++ b/hw/rtl/tcu/VX_tcu_sel.sv @@ -0,0 +1,117 @@ +// Copyright 2019-2023 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +`include "VX_define.vh" + +/* verilator lint_off UNUSEDSIGNAL */ + +module VX_tcu_sel import VX_gpu_pkg::*, VX_tcu_pkg::*; #( + parameter `STRING INSTANCE_ID = "", + parameter META_ROW_WIDTH = 16, + parameter I_RATIO = 4, + parameter ELT_W = 8 +) ( + input wire [TCU_TC_K-1:0][31:0] b_col_1, + input wire [TCU_TC_K-1:0][31:0] b_col_2, + input wire [META_ROW_WIDTH-1:0] vld_meta_row, + output wire [TCU_TC_K-1:0][31:0] b_col +); + `UNUSED_SPARAM (INSTANCE_ID); + + for (genvar k = 0; k < TCU_TC_K; ++k) begin : g_bmux + + if (I_RATIO == 4) begin : g_ratio4 + // int8: select 2 valid from each 4-element group + wire [I_RATIO-1:0] grp_mask_lo = vld_meta_row[I_RATIO * k +: I_RATIO]; + wire [I_RATIO-1:0] grp_mask_hi = vld_meta_row[I_RATIO * (TCU_TC_K + k) +: I_RATIO]; + + wire [ELT_W-1:0] lo_0 = grp_mask_lo[0] ? b_col_1[k][0*ELT_W +: ELT_W] : + grp_mask_lo[1] ? 
b_col_1[k][1*ELT_W +: ELT_W] : + b_col_1[k][2*ELT_W +: ELT_W]; + wire [ELT_W-1:0] lo_1 = grp_mask_lo[3] ? b_col_1[k][3*ELT_W +: ELT_W] : + grp_mask_lo[2] ? b_col_1[k][2*ELT_W +: ELT_W] : + b_col_1[k][1*ELT_W +: ELT_W]; + + wire [ELT_W-1:0] hi_0 = grp_mask_hi[0] ? b_col_2[k][0*ELT_W +: ELT_W] : + grp_mask_hi[1] ? b_col_2[k][1*ELT_W +: ELT_W] : + b_col_2[k][2*ELT_W +: ELT_W]; + wire [ELT_W-1:0] hi_1 = grp_mask_hi[3] ? b_col_2[k][3*ELT_W +: ELT_W] : + grp_mask_hi[2] ? b_col_2[k][2*ELT_W +: ELT_W] : + b_col_2[k][1*ELT_W +: ELT_W]; + + assign b_col[k] = {hi_1, hi_0, lo_1, lo_0}; + + end else if (I_RATIO == 2) begin : g_ratio2 + // fp16: select 2 valid from combined 4-element group + wire [I_RATIO-1:0] mask_lo = vld_meta_row[I_RATIO * k +: I_RATIO]; + wire [I_RATIO-1:0] mask_hi = vld_meta_row[I_RATIO * (TCU_TC_K + k) +: I_RATIO]; + wire [3:0] grp_mask = {mask_hi, mask_lo}; + + wire [ELT_W-1:0] elem0 = b_col_1[k][0 +: ELT_W]; + wire [ELT_W-1:0] elem1 = b_col_1[k][ELT_W +: ELT_W]; + wire [ELT_W-1:0] elem2 = b_col_2[k][0 +: ELT_W]; + wire [ELT_W-1:0] elem3 = b_col_2[k][ELT_W +: ELT_W]; + + // First valid (LSB), last valid (MSB) + wire [ELT_W-1:0] sel_0 = grp_mask[0] ? elem0 : + grp_mask[1] ? elem1 : + grp_mask[2] ? elem2 : elem3; + + wire [ELT_W-1:0] sel_1 = grp_mask[3] ? elem3 : + grp_mask[2] ? elem2 : + grp_mask[1] ? elem1 : elem0; + + assign b_col[k] = {sel_1, sel_0}; + + end else if (I_RATIO == 8) begin : g_ratio8 + // int4: 4 sub-groups of 4 nibbles each + wire [I_RATIO-1:0] grp_mask_lo = vld_meta_row[I_RATIO * k +: I_RATIO]; + wire [I_RATIO-1:0] grp_mask_hi = vld_meta_row[I_RATIO * (TCU_TC_K + k) +: I_RATIO]; + wire [3:0] sg0_mask = grp_mask_lo[3:0]; + wire [3:0] sg1_mask = grp_mask_lo[7:4]; + wire [3:0] sg2_mask = grp_mask_hi[3:0]; + wire [3:0] sg3_mask = grp_mask_hi[7:4]; + + wire [ELT_W-1:0] sg0_0 = sg0_mask[0] ? b_col_1[k][0*ELT_W +: ELT_W] : + sg0_mask[1] ? b_col_1[k][1*ELT_W +: ELT_W] : + b_col_1[k][2*ELT_W +: ELT_W]; + wire [ELT_W-1:0] sg0_1 = sg0_mask[3] ? 
b_col_1[k][3*ELT_W +: ELT_W] : + sg0_mask[2] ? b_col_1[k][2*ELT_W +: ELT_W] : + b_col_1[k][1*ELT_W +: ELT_W]; + wire [ELT_W-1:0] sg1_0 = sg1_mask[0] ? b_col_1[k][4*ELT_W +: ELT_W] : + sg1_mask[1] ? b_col_1[k][5*ELT_W +: ELT_W] : + b_col_1[k][6*ELT_W +: ELT_W]; + wire [ELT_W-1:0] sg1_1 = sg1_mask[3] ? b_col_1[k][7*ELT_W +: ELT_W] : + sg1_mask[2] ? b_col_1[k][6*ELT_W +: ELT_W] : + b_col_1[k][5*ELT_W +: ELT_W]; + wire [ELT_W-1:0] sg2_0 = sg2_mask[0] ? b_col_2[k][0*ELT_W +: ELT_W] : + sg2_mask[1] ? b_col_2[k][1*ELT_W +: ELT_W] : + b_col_2[k][2*ELT_W +: ELT_W]; + wire [ELT_W-1:0] sg2_1 = sg2_mask[3] ? b_col_2[k][3*ELT_W +: ELT_W] : + sg2_mask[2] ? b_col_2[k][2*ELT_W +: ELT_W] : + b_col_2[k][1*ELT_W +: ELT_W]; + wire [ELT_W-1:0] sg3_0 = sg3_mask[0] ? b_col_2[k][4*ELT_W +: ELT_W] : + sg3_mask[1] ? b_col_2[k][5*ELT_W +: ELT_W] : + b_col_2[k][6*ELT_W +: ELT_W]; + wire [ELT_W-1:0] sg3_1 = sg3_mask[3] ? b_col_2[k][7*ELT_W +: ELT_W] : + sg3_mask[2] ? b_col_2[k][6*ELT_W +: ELT_W] : + b_col_2[k][5*ELT_W +: ELT_W]; + + assign b_col[k] = {sg3_1, sg3_0, sg2_1, sg2_0, sg1_1, sg1_0, sg0_1, sg0_0}; + end + + end + +endmodule + +/* verilator lint_on UNUSEDSIGNAL */ diff --git a/hw/rtl/tcu/VX_tcu_uops.sv b/hw/rtl/tcu/VX_tcu_uops.sv index 568f83880..40f0353e9 100644 --- a/hw/rtl/tcu/VX_tcu_uops.sv +++ b/hw/rtl/tcu/VX_tcu_uops.sv @@ -33,8 +33,12 @@ module VX_tcu_uops import localparam LG_M = $clog2(TCU_M_STEPS); localparam LG_K = $clog2(TCU_K_STEPS); - localparam LG_A_SB = $clog2(TCU_A_SUB_BLOCKS); - localparam LG_B_SB = $clog2(TCU_B_SUB_BLOCKS); + localparam LG_A_SB = $clog2(TCU_A_SUB_BLOCKS); + localparam LG_B_SB = $clog2(TCU_B_SUB_BLOCKS); + localparam LG_B_SB_SP = $clog2(TCU_B_SUB_BLOCKS_SP); + + wire is_sparse_in = (ibuf_in.op_type == INST_TCU_WMMA_SP); + reg is_sparse; // uop counter reg [CTR_W-1:0] counter; @@ -61,9 +65,15 @@ module VX_tcu_uops import assign k_index = 0; end - // Register offsets - wire [CTR_W-1:0] rs1_offset = ((CTR_W'(m_index) >> LG_A_SB) << LG_K) | 
CTR_W'(k_index); - wire [CTR_W-1:0] rs2_offset = ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB; + // Register offsets — dense vs sparse formulas + wire [CTR_W-1:0] rs1_offset = is_sparse + ? ((CTR_W'(m_index) >> LG_A_SB) << (LG_K/2)) | CTR_W'(k_index) + : ((CTR_W'(m_index) >> LG_A_SB) << LG_K) | CTR_W'(k_index); + + wire [CTR_W-1:0] rs2_offset = is_sparse + ? ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB_SP + : ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB; + wire [CTR_W-1:0] rs3_offset = (CTR_W'(m_index) << LG_N) | CTR_W'(n_index); // Register calculations @@ -88,6 +98,7 @@ module VX_tcu_uops import assign ibuf_out.op_args.tcu.fmt_d = ibuf_in.op_args.tcu.fmt_d; assign ibuf_out.op_args.tcu.step_m = 4'(m_index); assign ibuf_out.op_args.tcu.step_n = 4'(n_index); + assign ibuf_out.op_args.tcu.step_k = 4'(k_index); assign ibuf_out.wb = 1; assign ibuf_out.rd_xregs = ibuf_in.rd_xregs; assign ibuf_out.wr_xregs = ibuf_in.wr_xregs; @@ -106,16 +117,21 @@ module VX_tcu_uops import always_ff @(posedge clk) begin if (reset) begin - counter <= 0; - busy <= 0; - done <= 0; + counter <= 0; + busy <= 0; + done <= 0; + is_sparse <= 0; end else begin if (~busy && start) begin - busy <= 1; - done <= (TCU_UOPS == 1); + counter <= 0; + busy <= 1; + is_sparse <= is_sparse_in; + done <= is_sparse_in ? (TCU_UOPS/2 == 1) : (TCU_UOPS == 1); end else if (busy && next) begin counter <= counter + ((TCU_UOPS > 1) ? 1 : 0); - done <= (counter == CTR_W'(TCU_UOPS-2)); + done <= is_sparse + ? 
(counter == CTR_W'((TCU_UOPS/2)-2)) + : (counter == CTR_W'(TCU_UOPS-2)); busy <= ~done; end end diff --git a/kernel/include/vx_tensor.h b/kernel/include/vx_tensor.h index 3749cdf1e..4a1112a92 100644 --- a/kernel/include/vx_tensor.h +++ b/kernel/include/vx_tensor.h @@ -117,7 +117,6 @@ namespace detail { return *reinterpret_cast(&result_u); } }; - } template + template static __attribute__((always_inline)) void load_matrix_sync(Frag &dst, const void *src, size_t ldm) { uint32_t lane = vx_thread_id(); if constexpr (Frag::Use == matrix_a) { @@ -188,61 +187,135 @@ struct wmma_context { if constexpr (src_layout == col_major) { std::swap(block_row, block_col); } - auto base = reinterpret_cast(src) + block_row * ldm + block_col; - detail::unroll_for([&](auto r) { - uint32_t block_m = r / cfg::k_steps; - uint32_t block_k = r % cfg::k_steps; - uint32_t elem_row = block_m * m_stride; - uint32_t elem_col = block_k * k_stride; - if constexpr (src_layout == col_major) { - static_assert(input_is_subbyte == false, "col_major layout is not supported for sub-byte matrix_a"); - std::swap(elem_row, elem_col); - auto ptr = base + elem_row * ldm + elem_col; - if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) { - dst.data[r] = *reinterpret_cast(ptr); + if constexpr (sparse) { + // Sparse A load: only load half the K-steps (compressed A) + constexpr uint32_t sparse_k_steps = cfg::k_steps / 2; + constexpr uint32_t sparse_regs = cfg::m_steps * sparse_k_steps; + auto base = reinterpret_cast(src) + block_row * ldm + block_col; + detail::unroll_for([&](auto r) { + uint32_t block_m = r / sparse_k_steps; + uint32_t block_k = r % sparse_k_steps; + uint32_t elem_row = block_m * m_stride; + uint32_t elem_col = block_k * k_stride; + if constexpr (src_layout == col_major) { + static_assert(input_is_subbyte == false, "col_major layout is not supported for sub-byte matrix_a"); + std::swap(elem_row, elem_col); + if constexpr (r < sparse_regs) { + auto ptr = base + elem_row * ldm + 
elem_col; + if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) { + dst.data[r] = *reinterpret_cast(ptr); + } else { + dst.data[r] = input_acessor_t::pack_row(ptr, ldm); + } + } else { + uint32_t zero = 0; + dst.data[r] = *reinterpret_cast(&zero); + } } else { - dst.data[r] = input_acessor_t::pack_row(ptr, ldm); + // row_major layout + if constexpr (r < sparse_regs) { + auto ptr = base + elem_row * ldm + elem_col; + assert(reinterpret_cast(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); + dst.data[r] = *reinterpret_cast(ptr); + } else { + uint32_t zero = 0; + dst.data[r] = *reinterpret_cast(&zero); + } } - } else { - // raw_major layout - auto ptr = base + elem_row * ldm + elem_col; - assert(reinterpret_cast(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); - dst.data[r] = *reinterpret_cast(ptr); - } - }); - } else if constexpr (Frag::Use == matrix_b) { - // Load column-major matrix B - uint32_t block_idx = (cfg::b_block_size == NT) ? 0 : (lane / cfg::b_block_size); - uint32_t lane_in_blk = (cfg::b_block_size == NT) ? 
lane : (lane % cfg::b_block_size); - uint32_t block_col = (lane_in_blk / cfg::tcK) + (block_idx * cfg::tcN); - uint32_t block_row = (lane_in_blk % cfg::tcK) * i_ratio; - uint32_t n_stride = cfg::b_sub_blocks * cfg::tcN; - uint32_t k_stride = cfg::tcK * i_ratio; - if constexpr (src_layout == col_major) { - std::swap(block_row, block_col); + }); + } else { + // Dense A load: load all K-steps + auto base = reinterpret_cast(src) + block_row * ldm + block_col; + detail::unroll_for([&](auto r) { + uint32_t block_m = r / cfg::k_steps; + uint32_t block_k = r % cfg::k_steps; + uint32_t elem_row = block_m * m_stride; + uint32_t elem_col = block_k * k_stride; + if constexpr (src_layout == col_major) { + static_assert(input_is_subbyte == false, "col_major layout is not supported for sub-byte matrix_a"); + std::swap(elem_row, elem_col); + auto ptr = base + elem_row * ldm + elem_col; + if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) { + dst.data[r] = *reinterpret_cast(ptr); + } else { + dst.data[r] = input_acessor_t::pack_row(ptr, ldm); + } + } else { + auto ptr = base + elem_row * ldm + elem_col; + assert(reinterpret_cast(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); + dst.data[r] = *reinterpret_cast(ptr); + } + }); } - auto base = reinterpret_cast(src) + block_row * ldm + block_col; - detail::unroll_for([&](auto r) { - uint32_t block_k = r / cfg::b_sub_steps; - uint32_t block_n = r % cfg::b_sub_steps; - uint32_t elem_row = block_k * k_stride; - uint32_t elem_col = block_n * n_stride; - if constexpr (src_layout == row_major) { - static_assert(input_is_subbyte == false, "row_major layout is not supported for sub-byte matrix_b"); - auto ptr = base + elem_row * ldm + elem_col; - if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) { - dst.data[r] = *reinterpret_cast(ptr); + } else if constexpr (Frag::Use == matrix_b) { + if constexpr (sparse) { + // Sparse B load: uses 2x tcK for B block + constexpr uint32_t b_tcK = 
cfg::tcK * 2; + uint32_t block_idx = (cfg::b_block_size_sp == NT) ? 0 : (lane / cfg::b_block_size_sp); + uint32_t lane_in_blk = (cfg::b_block_size_sp == NT) ? lane : (lane % cfg::b_block_size_sp); + uint32_t block_col = (lane_in_blk / b_tcK) + (block_idx * cfg::tcN); + uint32_t block_row = (lane_in_blk % b_tcK) * i_ratio; + uint32_t n_stride = cfg::b_sub_blocks_sp * cfg::tcN; + uint32_t k_stride = b_tcK * i_ratio; + if constexpr (src_layout == col_major) { + std::swap(block_row, block_col); + } + auto base = reinterpret_cast(src) + block_row * ldm + block_col; + detail::unroll_for([&](auto r) { + uint32_t block_k = r / cfg::b_sub_steps_sp; + uint32_t block_n = r % cfg::b_sub_steps_sp; + uint32_t elem_row = block_k * k_stride; + uint32_t elem_col = block_n * n_stride; + if constexpr (src_layout == row_major) { + static_assert(input_is_subbyte == false, "row_major layout is not supported for sub-byte matrix_b"); + auto ptr = base + elem_row * ldm + elem_col; + if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) { + dst.data[r] = *reinterpret_cast(ptr); + } else { + dst.data[r] = input_acessor_t::pack_row(ptr, ldm); + } } else { - dst.data[r] = input_acessor_t::pack_row(ptr, ldm); + // col_major layout + std::swap(elem_row, elem_col); + auto ptr = base + elem_row * ldm + elem_col; + assert(reinterpret_cast(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); + dst.data[r] = *reinterpret_cast(ptr); } - } else { - // col_major layout - std::swap(elem_row, elem_col); - auto ptr = base + elem_row * ldm + elem_col; - assert(reinterpret_cast(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); - dst.data[r] = *reinterpret_cast(ptr); + }); + } else { + // Dense B load + uint32_t block_idx = (cfg::b_block_size == NT) ? 0 : (lane / cfg::b_block_size); + uint32_t lane_in_blk = (cfg::b_block_size == NT) ? 
lane : (lane % cfg::b_block_size); + uint32_t block_col = (lane_in_blk / cfg::tcK) + (block_idx * cfg::tcN); + uint32_t block_row = (lane_in_blk % cfg::tcK) * i_ratio; + uint32_t n_stride = cfg::b_sub_blocks * cfg::tcN; + uint32_t k_stride = cfg::tcK * i_ratio; + if constexpr (src_layout == col_major) { + std::swap(block_row, block_col); } - }); + auto base = reinterpret_cast(src) + block_row * ldm + block_col; + detail::unroll_for([&](auto r) { + uint32_t block_k = r / cfg::b_sub_steps; + uint32_t block_n = r % cfg::b_sub_steps; + uint32_t elem_row = block_k * k_stride; + uint32_t elem_col = block_n * n_stride; + if constexpr (src_layout == row_major) { + static_assert(input_is_subbyte == false, "row_major layout is not supported for sub-byte matrix_b"); + auto ptr = base + elem_row * ldm + elem_col; + if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) { + dst.data[r] = *reinterpret_cast(ptr); + } else { + dst.data[r] = input_acessor_t::pack_row(ptr, ldm); + } + } else { + // col_major layout + std::swap(elem_row, elem_col); + auto ptr = base + elem_row * ldm + elem_col; + assert(reinterpret_cast(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); + dst.data[r] = *reinterpret_cast(ptr); + } + }); + } } else { // Load accumulator matrix C uint32_t block_row = lane / cfg::tcN; @@ -375,7 +448,30 @@ struct wmma_context { }); } - template + template + static __attribute__((always_inline)) void meta_store(float data) { + __asm__ volatile(".insn r 0x0b, 2, 2, x%[col], %[data], x0" // RISCV_CUSTOM0 instead of 0b + :: [col]"i"(COL), [data]"f"(data)); + } + +// // Set thread mask // "memory" comment stop compiler reordering. 
+// inline void vx_tmc(int thread_mask) { +// __asm__ volatile (".insn r %0, 0, 0, x0, %1, x0" :: "i"(RISCV_CUSTOM0), "r"(thread_mask) : "memory"); +// } + + + static __attribute__((always_inline)) void load_metadata_sync(const void* meta_ptr) { + constexpr uint32_t rtl_i_ratio = 32 / It::bits; + constexpr uint32_t num_cols = (NT * 2 * rtl_i_ratio) / 32; + uint32_t lane_id = vx_thread_id(); + auto base = reinterpret_cast(meta_ptr); + detail::unroll_for([&](auto col) { + float data = base[lane_id * num_cols + col]; + meta_store(data); + }); + } + + template static __attribute__((always_inline)) void mma_sync(FragD &fragD, const FragA &fragA, const FragB &fragB, const FragC &fragC) { static_assert(FragA::Use == matrix_a, "A must be matrix_a"); static_assert(FragB::Use == matrix_b, "B must be matrix_b"); @@ -423,25 +519,23 @@ struct wmma_context { register float fd6 __asm__("f30"); register float fd7 __asm__("f31"); - __asm__ volatile (".insn r %[insn], 0, 2, x%[fmd], x%[fms], x0" + constexpr int funct3 = sparse ? 
1 : 0; + __asm__ volatile (".insn r %[insn], %[f3], 2, x%[fmd], x%[fms], x0" : "=f"(fd0), "=f"(fd1), "=f"(fd2), "=f"(fd3), "=f"(fd4), "=f"(fd5), "=f"(fd6), "=f"(fd7) - : [insn]"i"(RISCV_CUSTOM0), [fmd]"i"(Ot::id), [fms]"i"(It::id), + : [insn]"i"(RISCV_CUSTOM0), [f3]"i"(funct3), [fmd]"i"(Ot::id), [fms]"i"(It::id), "f"(fa0), "f"(fa1), "f"(fa2), "f"(fa3), "f"(fa4), "f"(fa5), "f"(fa6), "f"(fa7), "f"(fb0), "f"(fb1), "f"(fb2), "f"(fb3), "f"(fb4), "f"(fb5), "f"(fb6), "f"(fb7), "f"(fc0), "f"(fc1), "f"(fc2), "f"(fc3), "f"(fc4), "f"(fc5), "f"(fc6), "f"(fc7) ); - // Write results to fragD fragD.data = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; } else { static_assert(FragB::NR == 4, "Unsupported number of registers for FragB"); - // fragB: caller-saved registers (f28-f31) register float fb0 __asm__("f28") = fragB.data[0]; register float fb1 __asm__("f29") = fragB.data[1]; register float fb2 __asm__("f30") = fragB.data[2]; register float fb3 __asm__("f31") = fragB.data[3]; - // fragC: mix of caller-saved (f10-f17) register float fc0 __asm__("f10") = fragC.data[0]; register float fc1 __asm__("f11") = fragC.data[1]; register float fc2 __asm__("f12") = fragC.data[2]; @@ -451,7 +545,6 @@ struct wmma_context { register float fc6 __asm__("f16") = fragC.data[6]; register float fc7 __asm__("f17") = fragC.data[7]; - // Force outputs into accumulator registers register float fd0 __asm__("f10"); register float fd1 __asm__("f11"); register float fd2 __asm__("f12"); @@ -461,158 +554,19 @@ struct wmma_context { register float fd6 __asm__("f16"); register float fd7 __asm__("f17"); - __asm__ volatile (".insn r %[insn], 0, 2, x%[fmd], x%[fms], x0" + constexpr int funct3 = sparse ? 
1 : 0; + __asm__ volatile (".insn r %[insn], %[f3], 2, x%[fmd], x%[fms], x0" : "=f"(fd0), "=f"(fd1), "=f"(fd2), "=f"(fd3), "=f"(fd4), "=f"(fd5), "=f"(fd6), "=f"(fd7) - : [insn]"i"(RISCV_CUSTOM0), [fmd]"i"(Ot::id), [fms]"i"(It::id), + : [insn]"i"(RISCV_CUSTOM0), [f3]"i"(funct3), [fmd]"i"(Ot::id), [fms]"i"(It::id), "f"(fa0), "f"(fa1), "f"(fa2), "f"(fa3), "f"(fa4), "f"(fa5), "f"(fa6), "f"(fa7), "f"(fb0), "f"(fb1), "f"(fb2), "f"(fb3), "f"(fc0), "f"(fc1), "f"(fc2), "f"(fc3), "f"(fc4), "f"(fc5), "f"(fc6), "f"(fc7) ); - // Write results to fragD fragD.data = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; } } - template - static __attribute__((always_inline)) void mma_sp_sync( - FragD &fragD, - const FragA &fragA, - const FragB &fragB, - const FragC &fragC, - const FragMeta &fragMeta) { - - static_assert(FragA::Use == matrix_a, "A must be matrix_a"); - static_assert(FragB::Use == matrix_b, "B must be matrix_b"); - static_assert(FragC::Use == accumulator, "C must be accumulator"); - static_assert(FragD::Use == accumulator, "D must be accumulator"); - static_assert(FragA::NR <= 8, "Unsupported number of registers for FragA"); - (void)fragMeta; - - // Temporary bring-up mode: - // - optionally bypass loaded A and inject fixed compressed A patterns -#ifndef VX_TCU_SP_USE_LOADED_A -#define VX_TCU_SP_USE_LOADED_A 1 -#endif - - auto bits_to_f32 = [](uint32_t bits) { - union { - uint32_t u; - float f; - } v; - v.u = bits; - return v.f; - }; - - register float fa0 __asm__("f0"); - register float fa1 __asm__("f1"); - register float fa2 __asm__("f2"); - register float fa3 __asm__("f3"); - register float fa4 __asm__("f4"); - register float fa5 __asm__("f5"); - register float fa6 __asm__("f6"); - register float fa7 __asm__("f7"); - - if constexpr (VX_TCU_SP_USE_LOADED_A) { - fa0 = (FragA::NR > 0) ? fragA.data[0] : 0.0f; - fa1 = (FragA::NR > 1) ? fragA.data[1] : 0.0f; - fa2 = (FragA::NR > 2) ? fragA.data[2] : 0.0f; - fa3 = (FragA::NR > 3) ? fragA.data[3] : 0.0f; - fa4 = (FragA::NR > 4) ? 
fragA.data[4] : 0.0f; - fa5 = (FragA::NR > 5) ? fragA.data[5] : 0.0f; - fa6 = (FragA::NR > 6) ? fragA.data[6] : 0.0f; - fa7 = (FragA::NR > 7) ? fragA.data[7] : 0.0f; - } else { - fa0 = bits_to_f32(0x03020100u); - fa1 = bits_to_f32(0x07060504u); - fa2 = bits_to_f32(0x0b0a0908u); - fa3 = bits_to_f32(0x0f0e0d0cu); - fa4 = 0.0f; - fa5 = 0.0f; - fa6 = 0.0f; - fa7 = 0.0f; - } - - if constexpr (FragB::NR == 8) { - // fragB: caller-saved registers (f10-f17) - register float fb0 __asm__("f10") = fragB.data[0]; - register float fb1 __asm__("f11") = fragB.data[1]; - register float fb2 __asm__("f12") = fragB.data[2]; - register float fb3 __asm__("f13") = fragB.data[3]; - register float fb4 __asm__("f14") = fragB.data[4]; - register float fb5 __asm__("f15") = fragB.data[5]; - register float fb6 __asm__("f16") = fragB.data[6]; - register float fb7 __asm__("f17") = fragB.data[7]; - - // fragC: accumulator registers (f24-f31) - register float fc0 __asm__("f24") = fragC.data[0]; - register float fc1 __asm__("f25") = fragC.data[1]; - register float fc2 __asm__("f26") = fragC.data[2]; - register float fc3 __asm__("f27") = fragC.data[3]; - register float fc4 __asm__("f28") = fragC.data[4]; - register float fc5 __asm__("f29") = fragC.data[5]; - register float fc6 __asm__("f30") = fragC.data[6]; - register float fc7 __asm__("f31") = fragC.data[7]; - - register float fd0 __asm__("f24"); - register float fd1 __asm__("f25"); - register float fd2 __asm__("f26"); - register float fd3 __asm__("f27"); - register float fd4 __asm__("f28"); - register float fd5 __asm__("f29"); - register float fd6 __asm__("f30"); - register float fd7 __asm__("f31"); - - // funct3=1 is sparse WMMA (simx decode support added separately). 
- __asm__ volatile (".insn r %[insn], 1, 2, x%[fmd], x%[fms], x0" - : "=f"(fd0), "=f"(fd1), "=f"(fd2), "=f"(fd3), "=f"(fd4), "=f"(fd5), "=f"(fd6), "=f"(fd7) - : [insn]"i"(RISCV_CUSTOM0), [fmd]"i"(Ot::id), [fms]"i"(It::id), - "f"(fa0), "f"(fa1), "f"(fa2), "f"(fa3), "f"(fa4), "f"(fa5), "f"(fa6), "f"(fa7), - "f"(fb0), "f"(fb1), "f"(fb2), "f"(fb3), "f"(fb4), "f"(fb5), "f"(fb6), "f"(fb7), - "f"(fc0), "f"(fc1), "f"(fc2), "f"(fc3), "f"(fc4), "f"(fc5), "f"(fc6), "f"(fc7) - ); - - fragD.data = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; - } else { - static_assert(FragB::NR == 4, "Unsupported number of registers for FragB"); - // fragB: caller-saved registers (f28-f31) - register float fb0 __asm__("f28") = fragB.data[0]; - register float fb1 __asm__("f29") = fragB.data[1]; - register float fb2 __asm__("f30") = fragB.data[2]; - register float fb3 __asm__("f31") = fragB.data[3]; - - // fragC: caller-saved registers (f10-f17) - register float fc0 __asm__("f10") = fragC.data[0]; - register float fc1 __asm__("f11") = fragC.data[1]; - register float fc2 __asm__("f12") = fragC.data[2]; - register float fc3 __asm__("f13") = fragC.data[3]; - register float fc4 __asm__("f14") = fragC.data[4]; - register float fc5 __asm__("f15") = fragC.data[5]; - register float fc6 __asm__("f16") = fragC.data[6]; - register float fc7 __asm__("f17") = fragC.data[7]; - - register float fd0 __asm__("f10"); - register float fd1 __asm__("f11"); - register float fd2 __asm__("f12"); - register float fd3 __asm__("f13"); - register float fd4 __asm__("f14"); - register float fd5 __asm__("f15"); - register float fd6 __asm__("f16"); - register float fd7 __asm__("f17"); - - // funct3=1 is sparse WMMA (simx decode support added separately). 
- __asm__ volatile (".insn r %[insn], 1, 2, x%[fmd], x%[fms], x0" - : "=f"(fd0), "=f"(fd1), "=f"(fd2), "=f"(fd3), "=f"(fd4), "=f"(fd5), "=f"(fd6), "=f"(fd7) - : [insn]"i"(RISCV_CUSTOM0), [fmd]"i"(Ot::id), [fms]"i"(It::id), - "f"(fa0), "f"(fa1), "f"(fa2), "f"(fa3), "f"(fa4), "f"(fa5), "f"(fa6), "f"(fa7), - "f"(fb0), "f"(fb1), "f"(fb2), "f"(fb3), - "f"(fc0), "f"(fc1), "f"(fc2), "f"(fc3), "f"(fc4), "f"(fc5), "f"(fc6), "f"(fc7) - ); - - fragD.data = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; - } - } - }; } // namespace tensor diff --git a/sim/common/tensor_cfg.h b/sim/common/tensor_cfg.h index e39c7fff2..f633888c7 100644 --- a/sim/common/tensor_cfg.h +++ b/sim/common/tensor_cfg.h @@ -191,10 +191,14 @@ struct wmma_config_t { static constexpr uint32_t a_sub_blocks = block_cap / a_block_size; // number of A micro-tiles per register static constexpr uint32_t a_sub_steps = m_steps / a_sub_blocks; // number of A sub-steps per register - static constexpr uint32_t b_block_size = tcK * tcN; // size of B micro-tile + static constexpr uint32_t b_block_size = tcK * tcN; // size of B micro-tile (dense) static constexpr uint32_t b_sub_blocks = block_cap / b_block_size; // number of B micro-tiles per register static constexpr uint32_t b_sub_steps = n_steps / b_sub_blocks; // number of B sub-steps per register + static constexpr uint32_t b_block_size_sp = (tcK * tcN) * 2; // sparse 2:4 + static constexpr uint32_t b_sub_blocks_sp = block_cap / b_block_size_sp; + static constexpr uint32_t b_sub_steps_sp = n_steps / b_sub_blocks_sp; + static constexpr uint32_t NRA = (xtileM * xtileK) / NT; // Number of A registers static constexpr uint32_t NRB = (xtileN * xtileK) / NT; // Number of B registers static constexpr uint32_t NRC = (xtileM * xtileN) / NT; // Number of C registers @@ -217,6 +221,13 @@ struct wmma_config_t { static constexpr uint32_t tileM = xtileM; static constexpr uint32_t tileN = xtileN; static constexpr uint32_t tileK = xtileK * i_ratio; // Adjusted for input type size + + // 
Metadata constants for 2:4 structured sparsity + static constexpr uint32_t itype_bits = It::bits; + static constexpr uint32_t rtl_i_ratio = 32 / itype_bits; + static constexpr uint32_t meta_block_width = NT * 2 * rtl_i_ratio; // bits + static constexpr uint32_t meta_cols = meta_block_width / 32; + static constexpr uint32_t per_warp_depth = m_steps * (k_steps / 2); }; } // namespace tensor diff --git a/sim/simx/emulator.cpp b/sim/simx/emulator.cpp index bbb3e6c65..e055cc39e 100644 --- a/sim/simx/emulator.cpp +++ b/sim/simx/emulator.cpp @@ -276,6 +276,11 @@ bool Emulator::wspawn(uint32_t num_warps, Word nextPC) { } uint32_t Emulator::get_barrier_phase(uint32_t bar_id) const { + bool is_global = (bar_id >> 31); + bar_id &= 0x7fffffff; + if (is_global) { + return core_->socket()->get_barrier_phase(bar_id); + } return barriers_.at(bar_id).phase; } diff --git a/tests/regression/sgemm_tcu/kernel.cpp b/tests/regression/sgemm_tcu/kernel.cpp index 83bccc0e7..3afff4d57 100644 --- a/tests/regression/sgemm_tcu/kernel.cpp +++ b/tests/regression/sgemm_tcu/kernel.cpp @@ -1,6 +1,7 @@ #include "common.h" #include #include +#include namespace vt = vortex::tensor; using ctx = vt::wmma_context; diff --git a/tests/regression/sgemm_tcu/main.cpp b/tests/regression/sgemm_tcu/main.cpp index 0d6bb12ca..8828cbebe 100644 --- a/tests/regression/sgemm_tcu/main.cpp +++ b/tests/regression/sgemm_tcu/main.cpp @@ -827,7 +827,7 @@ int main(int argc, char *argv[]) { // upload matrix B buffer { std::cout << "upload matrix B buffer" << std::endl; - if constexpr (std::is_same::value || + if constexpr (std::is_same::value || std::is_same::value || std::is_same::value) { // sub-byte matrix B must be in col-major format diff --git a/tests/regression/sgemm_tcu_struct_sparse/Makefile b/tests/regression/sgemm_tcu_struct_sparse/Makefile new file mode 100644 index 000000000..e2c7b0ee0 --- /dev/null +++ b/tests/regression/sgemm_tcu_struct_sparse/Makefile @@ -0,0 +1,16 @@ +ROOT_DIR := $(realpath ../../..) 
+include $(ROOT_DIR)/config.mk + +PROJECT := sgemm_tcu_struct_sparse + +SRC_DIR := $(VORTEX_HOME)/tests/regression/$(PROJECT) + +SRCS := $(SRC_DIR)/main.cpp $(SW_COMMON_DIR)/rvfloats.cpp $(SW_COMMON_DIR)/softfloat_ext.cpp + +VX_SRCS := $(SRC_DIR)/kernel.cpp + +CXXFLAGS += -I$(THIRD_PARTY_DIR)/softfloat/source/include + +LDFLAGS += $(THIRD_PARTY_DIR)/softfloat/build/Linux-x86_64-GCC/softfloat.a + +include ../common.mk \ No newline at end of file diff --git a/tests/regression/sgemm_tcu_struct_sparse/common.h b/tests/regression/sgemm_tcu_struct_sparse/common.h new file mode 100644 index 000000000..eaaf6b5fb --- /dev/null +++ b/tests/regression/sgemm_tcu_struct_sparse/common.h @@ -0,0 +1,29 @@ +#ifndef _COMMON_H_ +#define _COMMON_H_ + +#include + +#ifndef NUM_THREADS +#define NUM_THREADS 4 +#endif + +#ifndef ITYPE +#define ITYPE fp16 +#endif + +#ifndef OTYPE +#define OTYPE fp32 +#endif + +typedef struct { + uint32_t grid_dim[2]; + uint32_t block_dim[2]; + uint32_t M, N, K; + uint64_t A_addr; + uint64_t B_addr; + uint64_t C_addr; + uint64_t meta_addr; + uint64_t tcu_cycles_addr; +} kernel_arg_t; + +#endif diff --git a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp new file mode 100644 index 000000000..e79c3c715 --- /dev/null +++ b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp @@ -0,0 +1,68 @@ +#include "common.h" +#include +#include +#include + +namespace vt = vortex::tensor; +using ctx = vt::wmma_context; + +void kernel_body(kernel_arg_t *__UNIFORM__ arg) { + auto pA = reinterpret_cast(arg->A_addr); + auto pB = reinterpret_cast(arg->B_addr); + auto pC = reinterpret_cast(arg->C_addr); + auto pMetaBase = reinterpret_cast(arg->meta_addr); + + uint32_t M = arg->M; + uint32_t N = arg->N; + uint32_t K = arg->K; + + ctx::fragment_a fragA; + ctx::fragment_b fragB; + ctx::fragment_acc fragC; + + uint32_t tile_row = blockIdx.y * ctx::tileM; + uint32_t tile_col = blockIdx.x * ctx::tileN; + + 
ctx::fill_fragment(fragC, 0); + + // Per-K-tile metadata reload + constexpr uint32_t rtl_i_ratio = 32 / vt::ITYPE::bits; + constexpr uint32_t meta_cols = (NUM_THREADS * 2 * rtl_i_ratio) / 32; + constexpr uint32_t per_k_tile_words = NUM_THREADS * meta_cols; + uint32_t num_k_tiles = K / ctx::tileK; + uint32_t tile_row_idx = blockIdx.y; + + uint32_t stride_A = K / 2; + uint32_t cyc_start = csr_read(0xB00); + for (int i = 0; i < (int)K; i += (int)ctx::tileK) { + // Load metadata for this K-tile + uint32_t k_tile = i / ctx::tileK; + auto pMeta = pMetaBase + (tile_row_idx * num_k_tiles + k_tile) * per_k_tile_words; + ctx::load_metadata_sync(pMeta); + + auto pTileA = pA + tile_row * stride_A + (i / 2); + ctx::load_matrix_sync(fragA, pTileA, stride_A); + + if constexpr (vt::ITYPE::bits < 8) { + auto pTileB = pB + tile_col * K + i; + ctx::load_matrix_sync(fragB, pTileB, K); + } else { + auto pTileB = pB + i * N + tile_col; + ctx::load_matrix_sync(fragB, pTileB, N); + } + + ctx::mma_sync(fragC, fragA, fragB, fragC); + } + uint32_t cyc_end = csr_read(0xB00); + auto pCycles = reinterpret_cast(arg->tcu_cycles_addr); + uint32_t block_id = blockIdx.y * arg->grid_dim[0] + blockIdx.x; + pCycles[block_id] = cyc_end - cyc_start; + + auto pTileC = pC + tile_row * N + tile_col; + ctx::store_matrix_sync(pTileC, fragC, N); +} + +int main() { + auto arg = (kernel_arg_t *)csr_read(VX_CSR_MSCRATCH); + return vx_spawn_threads(2, arg->grid_dim, arg->block_dim, (vx_kernel_func_cb)kernel_body, arg); +} diff --git a/tests/regression/sgemm_tcu_struct_sparse/main.cpp b/tests/regression/sgemm_tcu_struct_sparse/main.cpp new file mode 100644 index 000000000..1b39316f9 --- /dev/null +++ b/tests/regression/sgemm_tcu_struct_sparse/main.cpp @@ -0,0 +1,1230 @@ +#include "common.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define FLOAT_ULP 6 +#define MAX_ERRORS 100 + +#define RT_CHECK(_expr) \ + do { \ + int _ret = _expr; \ + if (0 == _ret) \ + 
break; \ + printf("Error: '%s' returned %d!\n", #_expr, (int)_ret); \ + cleanup(); \ + exit(-1); \ + } while (false) + +using namespace vortex; +namespace vt = tensor; + +/////////////////////////////////////////////////////////////////////////////// + +static void convert_row_to_col_major_4bit(uint8_t *dst, uint32_t width, uint32_t height, const uint8_t *src) { + // Calculate output size and stride + uint32_t out_bytes = (width * height + 1) / 2; + memset(dst, 0, out_bytes); + uint32_t dst_stride = (height + 1) / 2; // Bytes per column in output + + // For each column in source (which becomes row in destination) + for (uint32_t c = 0; c < width; ++c) { + uint32_t base = c * dst_stride; + + // For each row in source (which becomes column in destination) + for (uint32_t r = 0; r < height; r += 2) { + // Calculate source indices (row-major) + uint32_t idx_even = r * width + c; + uint32_t idx_odd = (r + 1) * width + c; + + // Extract nibbles - consistent with data_accessor_t + uint8_t b_even = src[idx_even / 2]; + uint8_t b_odd = (r + 1 < height) ? src[idx_odd / 2] : 0; + + uint8_t nib_even = (idx_even & 1) ? (b_even >> 4) : (b_even & 0x0F); + uint8_t nib_odd = (r + 1 < height) + ? ((idx_odd & 1) ? (b_odd >> 4) : (b_odd & 0x0F)) + : 0; + + // Pack into destination: even row in low nibble, odd row in high nibble + dst[base + r / 2] = (nib_odd << 4) | nib_even; + } + } +} + +/////////////////////////////////////////////////////////////////////////////// + +template +struct data_accessor_t { + using Type = typename T::dtype; + static Type read(const Type *ptr, uint32_t offset) { + return ptr[offset]; + } + static void write(Type *ptr, uint32_t offset, Type value) { + ptr[offset] = value; + } +}; + +template <> +struct data_accessor_t { + static uint8_t read(const uint8_t *ptr, uint32_t offset) { + uint32_t row_off = offset / 2; + bool odd = offset & 0x1; + uint8_t value8 = ptr[row_off]; + return odd ? 
(value8 >> 4) : (value8 & 0x0f); // to nibble + } + static void write(uint8_t *ptr, uint32_t offset, int32_t value) { + uint32_t row_off = offset / 2; + bool odd = offset & 0x1; + uint8_t old_value = ptr[row_off]; + uint8_t new_value = odd ? ((old_value & 0x0f) | (value << 4)) + : ((old_value & 0xf0) | (value & 0x0f)); + ptr[offset / 2] = new_value; + } +}; + +template <> +struct data_accessor_t { + static uint8_t read(const uint8_t *ptr, uint32_t offset) { + uint32_t row_off = offset / 2; + bool odd = offset & 0x1; + uint8_t value8 = ptr[row_off]; + return odd ? (value8 >> 4) : (value8 & 0x0f); // to nibble + } + static void write(uint8_t *ptr, uint32_t offset, int32_t value) { + uint32_t row_off = offset / 2; + bool odd = offset & 0x1; + uint8_t old_value = ptr[row_off]; + uint8_t new_value = odd ? ((old_value & 0x0f) | (value << 4)) + : ((old_value & 0xf0) | (value & 0x0f)); + ptr[offset / 2] = new_value; + } +}; + +template <> +struct data_accessor_t { + static uint8_t read(const uint8_t *ptr, uint32_t offset) { + uint32_t row_off = offset / 2; + bool odd = offset & 0x1; + uint8_t value8 = ptr[row_off]; + return odd ? (value8 >> 4) : (value8 & 0x0f); // extract nibble + } + static void write(uint8_t *ptr, uint32_t offset, uint8_t value) { + uint32_t row_off = offset / 2; + bool odd = offset & 0x1; + uint8_t old_value = ptr[row_off]; + uint8_t new_value = odd ? 
((old_value & 0x0f) | (value << 4)) + : ((old_value & 0xf0) | (value & 0x0f)); + ptr[offset / 2] = new_value; + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +template +class Comparator {}; + +template <> +class Comparator { +public: + static int8_t generate() { + return (int8_t)rand(); + } + static bool compare(int8_t a, int8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static uint8_t generate() { + return (uint8_t)rand(); + } + static bool compare(uint8_t a, uint8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static uint8_t generate() { + return (uint8_t)rand(); // store 2 nibbles in a byte + } + static bool compare(uint8_t a, uint8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static uint8_t generate() { + return (uint8_t)rand(); // store 2 nibbles in a byte + } + static bool compare(uint8_t a, uint8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static int8_t generate() { + return (int8_t)(rand() % 256 - 128); + } + static bool compare(int8_t a, int8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> 
+class Comparator { +public: + static int32_t generate() { + return (int32_t)rand(); + } + static bool compare(int32_t a, int32_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static uint16_t generate() { + auto fvalue = float(rand()) / RAND_MAX; + return rv_ftoh_s(bit_cast(fvalue), 0, nullptr); + } + static bool compare(uint16_t a, uint16_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static uint16_t generate() { + auto fvalue = float(rand()) / RAND_MAX; + return rv_ftob_s(bit_cast(fvalue), 0, nullptr); + } + static bool compare(uint16_t a, uint16_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static uint8_t generate() { + auto fvalue = float(rand()) / RAND_MAX; + return rv_ftoe4m3_s(bit_cast(fvalue), 0, nullptr); + } + static bool compare(uint8_t a, uint8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static uint8_t generate() { + auto fvalue = float(rand()) / RAND_MAX; + return rv_ftoe5m2_s(bit_cast(fvalue), 0, nullptr); + } + static bool compare(uint8_t a, uint8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static 
uint32_t generate() { + auto fvalue = float(rand()) / RAND_MAX; + return rv_ftotf32_s(bit_cast(fvalue), 0, nullptr); + } + static bool compare(uint32_t a, uint32_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +// TODO: temp arbitrarily hardcoded scale factors +constexpr uint8_t SCALE_FACTOR_E8M0_A = 129; // val = 4, bias = 127 +constexpr uint8_t SCALE_FACTOR_E8M0_B = 131; // val = 16 +constexpr uint8_t SCALE_FACTOR_E4M3_A = 0x41; // val = 2.25, bias = 7 +constexpr uint8_t SCALE_FACTOR_E4M3_B = 0x33; // val = 0.6875 + +template <> +class Comparator { +public: + static uint8_t generate() { + return generate_with_scale(SCALE_FACTOR_E8M0_A); + } + + static uint8_t generate_with_scale(uint8_t scale_factor) { + auto fvalue = float(rand()) / RAND_MAX; + return rv_ftomxfp8_s(bit_cast(fvalue), scale_factor, 0, nullptr); + } + + static bool compare(uint8_t a, uint8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static uint8_t generate() { + return generate_with_scale(SCALE_FACTOR_E4M3_A); + } + + static uint8_t generate_with_scale(uint8_t scale_factor) { + auto fvalue = float(rand()) / RAND_MAX; + return rv_ftonvfp4_s(bit_cast(fvalue), scale_factor, 0, nullptr); + } + + static bool compare(uint8_t a, uint8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static float generate() { + return static_cast(rand()) / RAND_MAX; + } + static bool compare(float a, float b, int index, int errors) { + if constexpr (std::is_same::value || std::is_same::value || + 
std::is_same::value || std::is_same::value) { + if (a == 0.0f && b == 0.0f) { + return true; + } + //relative error tolerance + auto diff = std::abs((a - b)/b); + if (diff < 0.01f) { + return true; + } + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=%f, actual=%f\n", index, b, a); + } + return false; + } else { + union fi_t { + float f; + int32_t i; + }; + fi_t fa, fb; + fa.f = a; + fb.f = b; + auto d = std::abs(fa.i - fb.i); + if (d > FLOAT_ULP) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=%f, actual=%f\n", index, fb.f, fa.f); + } + return false; + } + return true; + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +template +struct muladd_t { + using stype = typename S::dtype; + using dtype = typename D::dtype; + static dtype eval(stype a, stype b, dtype c) { + return static_cast(a) * static_cast(b) + c; + } +}; + +template <> +struct muladd_t { + static float eval(uint16_t a, uint16_t b, float c) { + auto fa = bit_cast(rv_htof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_htof_s(b, 0, nullptr)); + return fa * fb + c; + } +}; + +template <> +struct muladd_t { + static uint16_t eval(uint16_t a, uint16_t b, uint16_t c) { + auto fa = bit_cast(rv_htof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_htof_s(b, 0, nullptr)); + auto fc = bit_cast(rv_htof_s(c, 0, nullptr)); + auto fd = fa * fb + fc; + return rv_ftoh_s(bit_cast(fd), 0, nullptr); + } +}; + +template <> +struct muladd_t { + static float eval(uint16_t a, uint16_t b, float c) { + auto fa = bit_cast(rv_btof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_btof_s(b, 0, nullptr)); + return fa * fb + c; + } +}; + +template <> +struct muladd_t { + static uint16_t eval(uint16_t a, uint16_t b, uint16_t c) { + auto fa = bit_cast(rv_btof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_btof_s(b, 0, nullptr)); + auto fc = bit_cast(rv_btof_s(c, 0, nullptr)); + auto fd = fa * fb + fc; + return rv_ftob_s(bit_cast(fd), 0, nullptr); + } +}; + +template <> +struct 
muladd_t { + static float eval(uint8_t a, uint8_t b, float c) { + auto fa = bit_cast(rv_e4m3tof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_e4m3tof_s(b, 0, nullptr)); + return fa * fb + c; + } +}; + +template <> +struct muladd_t { + static uint8_t eval(uint8_t a, uint8_t b, uint8_t c) { + auto fa = bit_cast(rv_e4m3tof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_e4m3tof_s(b, 0, nullptr)); + auto fc = bit_cast(rv_e4m3tof_s(c, 0, nullptr)); + auto fd = fa * fb + fc; + return rv_ftoe4m3_s(bit_cast(fd), 0, nullptr); + } +}; + +template <> +struct muladd_t { + static float eval(uint8_t a, uint8_t b, float c) { + auto fa = bit_cast(rv_e5m2tof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_e5m2tof_s(b, 0, nullptr)); + return fa * fb + c; + } +}; + +template <> +struct muladd_t { + static uint8_t eval(uint8_t a, uint8_t b, uint8_t c) { + auto fa = bit_cast(rv_e5m2tof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_e5m2tof_s(b, 0, nullptr)); + auto fc = bit_cast(rv_e5m2tof_s(c, 0, nullptr)); + auto fd = fa * fb + fc; + return rv_ftoe5m2_s(bit_cast(fd), 0, nullptr); + } +}; + +template <> +struct muladd_t { + static float eval(uint32_t a, uint32_t b, float c) { + auto fa = bit_cast(rv_tf32tof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_tf32tof_s(b, 0, nullptr)); + return fa * fb + c; + } +}; + +template <> +struct muladd_t { + static uint32_t eval(uint32_t a, uint32_t b, uint32_t c) { + auto fa = bit_cast(rv_tf32tof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_tf32tof_s(b, 0, nullptr)); + auto fc = bit_cast(rv_tf32tof_s(c, 0, nullptr)); + auto fd = fa * fb + fc; + return rv_ftotf32_s(bit_cast(fd), 0, nullptr); + } +}; + +template <> +struct muladd_t { + static float eval(uint8_t a, uint8_t b, float c) { + constexpr uint8_t sf_a = SCALE_FACTOR_E8M0_A; + constexpr uint8_t sf_b = SCALE_FACTOR_E8M0_B; + auto fa = bit_cast(rv_mxfp8tof_s(a, sf_a, 0, nullptr)); + auto fb = bit_cast(rv_mxfp8tof_s(b, sf_b, 0, nullptr)); + return fa * fb + c; + } +}; + +template <> +struct muladd_t { + static 
uint8_t eval(uint8_t a, uint8_t b, uint8_t c) { + constexpr uint8_t sf = SCALE_FACTOR_E8M0_A; + auto fa = bit_cast(rv_mxfp8tof_s(a, sf, 0, nullptr)); + auto fb = bit_cast(rv_mxfp8tof_s(b, sf, 0, nullptr)); + auto fc = bit_cast(rv_mxfp8tof_s(c, sf, 0, nullptr)); + auto fd = fa * fb + fc; + return rv_ftomxfp8_s(bit_cast(fd), sf, 0, nullptr); + } +}; + +template <> +struct muladd_t { + static float eval(uint8_t a, uint8_t b, float c) { + constexpr uint8_t sf_a = SCALE_FACTOR_E4M3_A; + constexpr uint8_t sf_b = SCALE_FACTOR_E4M3_B; + auto fa = bit_cast(rv_nvfp4tof_s(a, sf_a, 0, nullptr)); + auto fb = bit_cast(rv_nvfp4tof_s(b, sf_b, 0, nullptr)); + return fa * fb + c; + } +}; + +template <> +struct muladd_t { + static uint8_t eval(uint8_t a, uint8_t b, uint8_t c) { + constexpr uint8_t sf = SCALE_FACTOR_E4M3_A; + auto fa = bit_cast(rv_nvfp4tof_s(a, sf, 0, nullptr)); + auto fb = bit_cast(rv_nvfp4tof_s(b, sf, 0, nullptr)); + auto fc = bit_cast(rv_nvfp4tof_s(c, sf, 0, nullptr)); + auto fd = fa * fb + fc; + return rv_ftonvfp4_s(bit_cast(fd), sf, 0, nullptr); + } +}; + +template <> +struct muladd_t { + static int32_t eval(uint8_t a, uint8_t b, int32_t c) { + int32_t a_val = a & 0xF; + if (a & 0x8) { + a_val |= 0xFFFFFFF0; // sign extend + } + int32_t b_val = b & 0xF; + if (b & 0x8) { + b_val |= 0xFFFFFFF0; // sign extend + } + return a_val * b_val + c; + } +}; + +template <> +struct muladd_t { + static int32_t eval(uint8_t a, uint8_t b, int32_t c) { + int32_t a_val = a & 0xF; + int32_t b_val = b & 0xF; + return a_val * b_val + c; + } +}; + +template <> +struct muladd_t { + static int32_t eval(int8_t a, int8_t b, int32_t c) { + constexpr uint8_t sf_a = SCALE_FACTOR_E8M0_A; + constexpr uint8_t sf_b = SCALE_FACTOR_E8M0_B; + int32_t scale_exp_a = (int32_t)sf_a - 133; + float scale_factor_a = std::ldexp(1.0f, scale_exp_a); + int32_t scale_exp_b = (int32_t)sf_b - 133; + float scale_factor_b = std::ldexp(1.0f, scale_exp_b); + float product = (float)a * scale_factor_a * (float)b * 
scale_factor_b; + return (int32_t)product + c; + } +}; + +template +inline typename T::dtype generate_A_value() { + if constexpr (std::is_same_v) { + return Comparator::generate_with_scale(SCALE_FACTOR_E8M0_A); + } else if constexpr (std::is_same_v) { + return Comparator::generate_with_scale(SCALE_FACTOR_E4M3_A); + } else { + return Comparator::generate(); + } +} + +template +inline typename T::dtype generate_B_value() { + if constexpr (std::is_same_v) { + return Comparator::generate_with_scale(SCALE_FACTOR_E8M0_B); + } else if constexpr (std::is_same_v) { + return Comparator::generate_with_scale(SCALE_FACTOR_E4M3_B); + } else { + return Comparator::generate(); + } +} + +/////////////////////////////////////////////////////////////////////////////// + +using cfg = vt::wmma_config_t; + +using itype_t = typename vt::ITYPE::dtype; +using otype_t = typename vt::OTYPE::dtype; + + +// Dense CPU reference matmul (pruned A has zeros at masked positions) +static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t M, uint32_t N, uint32_t K) { + uint32_t subbytes = (vt::ITYPE::bits < 8) ? (8 / vt::ITYPE::bits) : 0; + uint32_t KS = subbytes ? 
(K * subbytes) : K; + for (uint32_t m = 0; m < M; ++m) { + for (uint32_t n = 0; n < N; ++n) { + otype_t sum(0); + for (uint32_t k = 0; k < KS; ++k) { + auto a = data_accessor_t::read(A, m * KS + k); + auto b = data_accessor_t::read(B, k * N + n); + sum = muladd_t::eval(a, b, sum); + } + data_accessor_t::write(C, m * N + n, sum); + } + } +} + +// Get magnitude of element at given offset in A matrix (for pruning comparison) +static float get_element_magnitude(const itype_t *A, uint32_t offset) { + auto val = data_accessor_t::read(A, offset); + if constexpr (std::is_same_v || std::is_same_v) { + return std::abs(static_cast(static_cast(val))); + } else if constexpr (std::is_same_v) { + return static_cast(val); + } else if constexpr (std::is_same_v) { + int32_t sval = val & 0xF; + if (sval & 0x8) sval |= ~0xF; + return std::abs(static_cast(sval)); + } else if constexpr (std::is_same_v) { + return static_cast(val & 0xF); + } else if constexpr (std::is_same_v) { + return std::abs(bit_cast(rv_htof_s(val, 0, nullptr))); + } else if constexpr (std::is_same_v) { + return std::abs(bit_cast(rv_btof_s(val, 0, nullptr))); + } else { + return std::abs(static_cast(val)); + } +} + +// Prune matrix A with real 2:4 structured sparsity (top-2 by magnitude per group of 4) +// Zeros pruned elements in-place and stores per-group 4-bit masks +static void prune_2to4(itype_t *A, std::vector &masks, uint32_t M, uint32_t K) { + uint32_t subbytes = (vt::ITYPE::bits < 8) ? (8 / vt::ITYPE::bits) : 0; + uint32_t KS = subbytes ? 
(K * subbytes) : K; + uint32_t num_groups = KS / 4; + masks.resize(M * num_groups); + + for (uint32_t m = 0; m < M; ++m) { + for (uint32_t g = 0; g < num_groups; ++g) { + uint32_t k_start = g * 4; + + // Get magnitudes + float mags[4]; + for (int p = 0; p < 4; ++p) { + mags[p] = get_element_magnitude(A, m * KS + k_start + p); + } + + // Find indices of top-2 by magnitude (ties broken by lower index) + int top[2] = {0, 1}; + if (mags[1] > mags[0]) { top[0] = 1; top[1] = 0; } + for (int p = 2; p < 4; ++p) { + if (mags[p] > mags[top[0]]) { + top[1] = top[0]; + top[0] = p; + } else if (mags[p] > mags[top[1]]) { + top[1] = p; + } + } + + // Build mask and zero pruned elements + uint8_t mask = (1 << top[0]) | (1 << top[1]); + masks[m * num_groups + g] = mask; + for (int p = 0; p < 4; ++p) { + if (!(mask & (1 << p))) { + data_accessor_t::write(A, m * KS + k_start + p, 0); + } + } + } + } +} + +// Compress pruned A (M x K) to M x K/2 using per-group masks +static void compress_2to4(itype_t *compressed, const itype_t *pruned_A, + const std::vector &masks, uint32_t M, uint32_t K) { + uint32_t subbytes = (vt::ITYPE::bits < 8) ? (8 / vt::ITYPE::bits) : 0; + uint32_t KS = subbytes ? 
(K * subbytes) : K; + uint32_t stride_comp = KS / 2; + uint32_t num_groups = KS / 4; + + for (uint32_t m = 0; m < M; ++m) { + uint32_t a_out = 0; + for (uint32_t g = 0; g < num_groups; ++g) { + uint32_t k_start = g * 4; + uint8_t mask = masks[m * num_groups + g]; + for (uint32_t k2 = 0; k2 < 4; ++k2) { + if (mask & (1 << k2)) { + auto val = data_accessor_t::read(pruned_A, m * KS + k_start + k2); + data_accessor_t::write(compressed, m * stride_comp + a_out, val); + a_out++; + } + } + } + } +} + +// Pack per-group masks into VX_tcu_meta SRAM layout +// Output: h_meta vector indexed as [tile_row][k_tile][NT * meta_cols words] +static void pack_metadata(std::vector &h_meta, + const std::vector &masks, + uint32_t M, uint32_t K) { + constexpr uint32_t I_RATIO = cfg::rtl_i_ratio; + constexpr uint32_t TC_K = cfg::tcK; + constexpr uint32_t TC_M = cfg::tcM; + constexpr uint32_t meta_row_w = TC_K * 2 * I_RATIO; + constexpr uint32_t mcols = cfg::meta_cols; + constexpr uint32_t half_k_steps = cfg::k_steps / 2; + + uint32_t subbytes = (vt::ITYPE::bits < 8) ? (8 / vt::ITYPE::bits) : 0; + uint32_t tileK_elem = subbytes ? (cfg::tileK * subbytes) : cfg::tileK; + uint32_t KS = subbytes ? 
(K * subbytes) : K; + uint32_t num_groups_per_row = KS / 4; + uint32_t elts_per_sparse_step = tileK_elem / half_k_steps; + + uint32_t num_tile_rows = M / cfg::tileM; + uint32_t num_k_tiles = K / cfg::tileK; + uint32_t per_k_tile_words = NUM_THREADS * mcols; + + h_meta.assign(num_tile_rows * num_k_tiles * per_k_tile_words, 0); + + for (uint32_t tr = 0; tr < num_tile_rows; ++tr) { + for (uint32_t kt = 0; kt < num_k_tiles; ++kt) { + uint32_t section_base = (tr * num_k_tiles + kt) * per_k_tile_words; + + for (uint32_t sm = 0; sm < cfg::m_steps; ++sm) { + for (uint32_t sk = 0; sk < half_k_steps; ++sk) { + uint32_t sram_row = sm * half_k_steps + sk; + + for (uint32_t i = 0; i < TC_M; ++i) { + uint32_t physical_row = tr * cfg::tileM + sm * TC_M + i; + uint32_t k_elem_start = kt * tileK_elem + sk * elts_per_sparse_step; + uint32_t groups_in_step = elts_per_sparse_step / 4; + + for (uint32_t g = 0; g < groups_in_step; ++g) { + uint32_t global_group = (k_elem_start / 4) + g; + uint8_t mask = masks[physical_row * num_groups_per_row + global_group]; + + for (int p = 0; p < 4; ++p) { + if (mask & (1 << p)) { + // Map element position to meta_row bit position + uint32_t elt = g * 4 + p; + uint32_t k_reg = elt / (2 * I_RATIO); + uint32_t pos_in_k = elt % (2 * I_RATIO); + uint32_t meta_bit; + if (pos_in_k < I_RATIO) { + meta_bit = k_reg * I_RATIO + pos_in_k; + } else { + meta_bit = (TC_K + k_reg) * I_RATIO + (pos_in_k - I_RATIO); + } + uint32_t block_bit = i * meta_row_w + meta_bit; + uint32_t word_idx = block_bit / 32; + uint32_t bit_idx = block_bit % 32; + h_meta[section_base + sram_row * mcols + word_idx] |= (1u << bit_idx); + } + } + } + } + } + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////// + +const char *kernel_file = "kernel.vxbin"; + +uint32_t xm = 32; +uint32_t xn = 32; +uint32_t xk = 32; + +vx_device_h device = nullptr; +vx_buffer_h A_buffer = nullptr; +vx_buffer_h B_buffer = nullptr; +vx_buffer_h C_buffer = nullptr; 
+vx_buffer_h meta_buffer = nullptr; +vx_buffer_h cycles_buffer = nullptr; +vx_buffer_h krnl_buffer = nullptr; +vx_buffer_h args_buffer = nullptr; +kernel_arg_t kernel_arg = {}; + +std::string last_build_options; + +static void show_usage() { + std::cout << "Vortex Sgemm TCU Test." << std::endl; + std::cout << "Usage: [-m: m] [-n N] [-k: K] [-h: help]" << std::endl; +} + +static void parse_args(int argc, char **argv) { + int c; + while ((c = getopt(argc, argv, "m:n:k:i:o:hs")) != -1) { + switch (c) { + case 'm': + xm = atoi(optarg); + break; + case 'n': + xn = atoi(optarg); + break; + case 'k': + xk = atoi(optarg); + break; + case 'h': + show_usage(); + exit(0); + break; + default: + show_usage(); + exit(-1); + } + } +} + +void cleanup() { + if (device) { + vx_mem_free(A_buffer); + vx_mem_free(B_buffer); + vx_mem_free(C_buffer); + vx_mem_free(meta_buffer); + vx_mem_free(cycles_buffer); + vx_mem_free(krnl_buffer); + vx_mem_free(args_buffer); + vx_dev_close(device); + } +} + + + +int main(int argc, char *argv[]) { + // parse command arguments + parse_args(argc, argv); + + std::srand(50); + + // open device connection + std::cout << "open device connection" << std::endl; + RT_CHECK(vx_dev_open(&device)); + + uint64_t isa_flags; + RT_CHECK(vx_dev_caps(device, VX_CAPS_ISA_FLAGS, &isa_flags)); + bool has_ext = (isa_flags & VX_ISA_EXT_TCU) != 0; + if (!has_ext) { + std::cout << "TCU extension not supported!" << std::endl; + cleanup(); + return -1; + } + + uint64_t NT; + RT_CHECK(vx_dev_caps(device, VX_CAPS_NUM_THREADS, &NT)); + if (NT != NUM_THREADS) { + std::cout << "Error: device warp size (" << NT << ") must match NUM_THREADS=" << NUM_THREADS << "!" << std::endl; + return -1; + } + + uint32_t M = xm; + uint32_t N = xn; + uint32_t K = xk; + + if ((M % cfg::tileM) != 0) { + std::cout << "Error: M must be a multiple of tensor tileM!" << std::endl; + return -1; + } + + if ((N % cfg::tileN) != 0) { + std::cout << "Error: M must be a multiple of tensor tileN!" 
<< std::endl; + return -1; + } + + if ((K % cfg::tileK) != 0) { + std::cout << "Error: M must be a multiple of tensor tileK!" << std::endl; + return -1; + } + + size_t sizeA_full = M * K; + size_t sizeA = (M * K) / 2; + size_t sizeB = K * N; + size_t sizeC = M * N; + + std::cout << "input data type: " << vt::ITYPE::name << " (id=" << vt::ITYPE::id << ")" << std::endl; + std::cout << "output data type: " << vt::OTYPE::name << " (id=" << vt::OTYPE::id << ")" << std::endl; + std::cout << "WMMA Core Dimension: M=" << cfg::tcM << ", N=" << cfg::tcN << ", K=" << cfg::tcK << std::endl; + std::cout << "WMMA Tile Dimension: M=" << cfg::tileM << ", N=" << cfg::tileN << ", K=" << cfg::tileK << std::endl; + std::cout << "matrix A: " << M << "x" << K << std::endl; + std::cout << "matrix B: " << K << "x" << N << std::endl; + std::cout << "matrix C: " << M << "x" << N << std::endl; + + // set block size to warp size + kernel_arg.grid_dim[0] = N / cfg::tileN; + kernel_arg.grid_dim[1] = M / cfg::tileM; + kernel_arg.block_dim[0] = NT; // warp sizeb + kernel_arg.block_dim[1] = 1; + + // set matrix dimensions + kernel_arg.M = M; + kernel_arg.N = N; + kernel_arg.K = K; + + // allocate device memory + std::cout << "allocate device memory" << std::endl; + RT_CHECK(vx_mem_alloc(device, sizeA * sizeof(itype_t), VX_MEM_READ, &A_buffer)); + RT_CHECK(vx_mem_address(A_buffer, &kernel_arg.A_addr)); + RT_CHECK(vx_mem_alloc(device, sizeB * sizeof(itype_t), VX_MEM_READ, &B_buffer)); + RT_CHECK(vx_mem_address(B_buffer, &kernel_arg.B_addr)); + RT_CHECK(vx_mem_alloc(device, sizeC * sizeof(otype_t), VX_MEM_WRITE, &C_buffer)); + RT_CHECK(vx_mem_address(C_buffer, &kernel_arg.C_addr)); + + // allocate metadata buffer per (tile_row, k_tile) + constexpr uint32_t meta_cols = cfg::meta_cols; + uint32_t num_tile_rows = M / cfg::tileM; + uint32_t num_k_tiles = K / cfg::tileK; + uint32_t meta_buf_entries = num_tile_rows * num_k_tiles * NUM_THREADS * meta_cols; + RT_CHECK(vx_mem_alloc(device, meta_buf_entries * 
sizeof(uint32_t), VX_MEM_READ, &meta_buffer)); + RT_CHECK(vx_mem_address(meta_buffer, &kernel_arg.meta_addr)); + + uint32_t num_blocks = kernel_arg.grid_dim[0] * kernel_arg.grid_dim[1]; + RT_CHECK(vx_mem_alloc(device, num_blocks * sizeof(uint32_t), VX_MEM_WRITE, &cycles_buffer)); + RT_CHECK(vx_mem_address(cycles_buffer, &kernel_arg.tcu_cycles_addr)); + + std::cout << "A_addr=0x" << std::hex << kernel_arg.A_addr << std::endl; + std::cout << "B_addr=0x" << std::hex << kernel_arg.B_addr << std::endl; + std::cout << "C_addr=0x" << std::hex << kernel_arg.C_addr << std::endl; + std::cout << "meta_addr=0x" << std::hex << kernel_arg.meta_addr << std::endl; + + // generate source data + // Generate full matrix A (M × K), prune in-place, then compress to M × K/2 + std::vector h_A_full(sizeA_full); + for (uint32_t i = 0; i < sizeA_full; ++i) { + h_A_full[i] = generate_A_value(); + } + std::vector masks; + prune_2to4(h_A_full.data(), masks, M, K); + std::vector h_A(sizeA); + compress_2to4(h_A.data(), h_A_full.data(), masks, M, K); + + std::vector h_B(sizeB); + for (uint32_t i = 0; i < sizeB; ++i) { + h_B[i] = generate_B_value(); + } + + // upload matrix A buffer + { + std::cout << "upload matrix A buffer" << std::endl; + RT_CHECK(vx_copy_to_dev(A_buffer, h_A.data(), 0, sizeA * sizeof(itype_t))); + } + + // upload matrix B buffer + { + std::cout << "upload matrix B buffer" << std::endl; + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value) { + // sub-byte matrix B must be in col-major format + // we convert the 4-bit row-major to col-major here + std::vector h_B_col(sizeB); + convert_row_to_col_major_4bit(h_B_col.data(), N, 2 * K, (uint8_t*)h_B.data()); + RT_CHECK(vx_copy_to_dev(B_buffer, h_B_col.data(), 0, sizeB)); + } else { + RT_CHECK(vx_copy_to_dev(B_buffer, h_B.data(), 0, sizeB * sizeof(itype_t))); + } + } + + // upload metadata buffer (real masks from pruning) + { + std::cout << "upload metadata buffer" << std::endl; + std::vector h_meta; 
+ pack_metadata(h_meta, masks, M, K); + RT_CHECK(vx_copy_to_dev(meta_buffer, h_meta.data(), 0, meta_buf_entries * sizeof(uint32_t))); + } + + // upload program + std::cout << "upload program" << std::endl; + RT_CHECK(vx_upload_kernel_file(device, kernel_file, &krnl_buffer)); + + // upload kernel argument + std::cout << "upload kernel argument" << std::endl; + RT_CHECK(vx_upload_bytes(device, &kernel_arg, sizeof(kernel_arg_t), &args_buffer)); + + auto time_start = std::chrono::high_resolution_clock::now(); + + // start device + std::cout << "start device" << std::endl; + RT_CHECK(vx_start(device, krnl_buffer, args_buffer)); + + // wait for completion + std::cout << "wait for completion" << std::endl; + RT_CHECK(vx_ready_wait(device, VX_MAX_TIMEOUT)); + + auto time_end = std::chrono::high_resolution_clock::now(); + double elapsed = std::chrono::duration_cast(time_end - time_start).count(); + printf("Elapsed time: %lg ms\n", elapsed); + + // download destination buffer + std::vector h_C(sizeC); + std::cout << "download destination buffer" << std::endl; + RT_CHECK(vx_copy_from_dev(h_C.data(), C_buffer, 0, sizeC * sizeof(otype_t))); + + // download TCU K-loop cycle counts + { + std::vector h_cycles(num_blocks); + RT_CHECK(vx_copy_from_dev(h_cycles.data(), cycles_buffer, 0, num_blocks * sizeof(uint32_t))); + uint32_t max_cyc = 0; + for (uint32_t i = 0; i < num_blocks; ++i) { + if (h_cycles[i] > max_cyc) max_cyc = h_cycles[i]; + } + printf("TCU_CYCLES: max=%u (across %u blocks)\n", max_cyc, num_blocks); + } + + // === DEBUG: dump masks, metadata, compressed A for row 0 === + { + uint32_t subbytes_d = (vt::ITYPE::bits < 8) ? (8 / vt::ITYPE::bits) : 0; + uint32_t KS_d = subbytes_d ? 
(K * subbytes_d) : K; + uint32_t num_groups_d = KS_d / 4; + std::cout << "=== DEBUG: ITYPE::bits=" << vt::ITYPE::bits + << " I_RATIO=" << cfg::rtl_i_ratio + << " TC_K=" << cfg::tcK << " TC_M=" << cfg::tcM + << " meta_cols=" << cfg::meta_cols + << " tileK=" << cfg::tileK + << " k_steps=" << cfg::k_steps + << " half_k_steps=" << cfg::k_steps/2 + << std::endl; + + // Print masks for row 0 + std::cout << "Masks row 0:"; + for (uint32_t g = 0; g < num_groups_d && g < 8; ++g) { + printf(" g%u=0x%x", g, masks[0 * num_groups_d + g]); + } + std::cout << std::endl; + + // Print compressed A for row 0 (first 8 elements) + uint32_t stride_comp_d = KS_d / 2; + std::cout << "Compressed A row 0 (hex):"; + for (uint32_t k = 0; k < stride_comp_d && k < 16; ++k) { + auto val = data_accessor_t::read(h_A.data(), 0 * stride_comp_d + k); + printf(" 0x%x", (unsigned)val); + } + std::cout << std::endl; + + // Recompute and print metadata words + std::vector h_meta_dbg; + pack_metadata(h_meta_dbg, masks, M, K); + constexpr uint32_t mcols_d = cfg::meta_cols; + uint32_t per_k_words_d = NUM_THREADS * mcols_d; + std::cout << "Metadata words (tile_row=0, k_tile=0):"; + for (uint32_t w = 0; w < per_k_words_d; ++w) { + printf(" [%u]=0x%08x", w, h_meta_dbg[w]); + } + std::cout << std::endl; + + // Decode metadata bits for sram_row 0 (sm=0, sk=0) + // Each sram_row has mcols_d words = mcols_d*32 bits + // TC_M rows, each META_ROW_WIDTH bits + constexpr uint32_t meta_row_w_d = cfg::tcK * 2 * cfg::rtl_i_ratio; + std::cout << " sram_row0 decoded (TC_M=" << cfg::tcM << " rows, " << meta_row_w_d << " bits each):" << std::endl; + uint32_t sram0_word = h_meta_dbg[0]; + for (uint32_t i = 0; i < cfg::tcM; ++i) { + uint32_t row_bits = (sram0_word >> (i * meta_row_w_d)) & ((1u << meta_row_w_d) - 1); + printf(" TC_M row %u: bits=0x%x (binary:", i, row_bits); + for (int b = meta_row_w_d-1; b >= 0; --b) printf("%d", (row_bits >> b) & 1); + printf(")\n"); + } + + // Show what pruned A looks like for row 0 (full 
K) + std::cout << "Pruned A row 0 (full, first 16 hex):"; + for (uint32_t k = 0; k < KS_d && k < 16; ++k) { + auto val = data_accessor_t::read(h_A_full.data(), 0 * KS_d + k); + printf(" 0x%x", (unsigned)val); + } + std::cout << std::endl; + } + // === END DEBUG === + + // verify result + std::cout << "verify result" << std::endl; + int errors = 0; + { + std::vector h_ref(sizeC); + matmul_cpu(h_ref.data(), h_A_full.data(), h_B.data(), M, N, K); + + // Sparse reference: manually compute using compressed A + mask-selected B + // This mimics exactly what the hardware should do + uint32_t subbytes_v = (vt::ITYPE::bits < 8) ? (8 / vt::ITYPE::bits) : 0; + uint32_t KS_v = subbytes_v ? (K * subbytes_v) : K; + uint32_t stride_comp_v = KS_v / 2; + uint32_t num_groups_v = KS_v / 4; + std::vector h_sparse_ref(sizeC); + for (uint32_t m = 0; m < M; ++m) { + for (uint32_t n = 0; n < N; ++n) { + otype_t sum(0); + uint32_t comp_idx = 0; + for (uint32_t g = 0; g < num_groups_v; ++g) { + uint8_t mask = masks[m * num_groups_v + g]; + // Extract first set and last set positions (matching VX_tcu_sel) + int first_set = -1, last_set = -1; + for (int p = 0; p < 4; ++p) { + if (mask & (1 << p)) { + if (first_set < 0) first_set = p; + last_set = p; + } + } + uint32_t k_base = g * 4; + // compressed A stores in ascending order: first_set then last_set + auto a_first = data_accessor_t::read(h_A.data(), m * stride_comp_v + comp_idx); + auto a_last = data_accessor_t::read(h_A.data(), m * stride_comp_v + comp_idx + 1); + auto b_first = data_accessor_t::read(h_B.data(), (k_base + first_set) * N + n); + auto b_last = data_accessor_t::read(h_B.data(), (k_base + last_set) * N + n); + sum = muladd_t::eval(a_first, b_first, sum); + sum = muladd_t::eval(a_last, b_last, sum); + comp_idx += 2; + } + data_accessor_t::write(h_sparse_ref.data(), m * N + n, sum); + } + } + + // Compare sparse ref with dense ref (should match) + int sparse_ref_errors = 0; + for (uint32_t i = 0; i < sizeC; ++i) { + if 
(!Comparator::compare(h_sparse_ref[i], h_ref[i], i, sparse_ref_errors)) { + if (sparse_ref_errors <= 5) { + printf(" sparse_ref[%u]=%f vs cpu_ref[%u]=%f\n", i, + static_cast(h_sparse_ref[i]), i, + static_cast(h_ref[i])); + } + ++sparse_ref_errors; + } + } + if (sparse_ref_errors > 0) { + printf("WARNING: sparse_ref vs cpu_ref: %d / %u mismatches!\n", sparse_ref_errors, sizeC); + } else { + printf("sparse_ref vs cpu_ref: ALL MATCH\n"); + } + + // Compare GPU output with sparse ref + int gpu_vs_sparse = 0; + for (uint32_t i = 0; i < sizeC; ++i) { + if (!Comparator::compare(h_C[i], h_sparse_ref[i], i, gpu_vs_sparse)) { + if (gpu_vs_sparse <= 5) { + printf(" gpu[%u]=%f vs sparse_ref[%u]=%f\n", i, + static_cast(h_C[i]), i, + static_cast(h_sparse_ref[i])); + } + ++gpu_vs_sparse; + } + } + if (gpu_vs_sparse > 0) { + printf("GPU vs sparse_ref: %d / %u mismatches\n", gpu_vs_sparse, sizeC); + } else { + printf("GPU vs sparse_ref: ALL MATCH\n"); + } + + // Print first few entries for manual inspection + printf("First 8 entries: cpu_ref / sparse_ref / gpu\n"); + for (uint32_t i = 0; i < 8 && i < sizeC; ++i) { + printf(" [%u] %f / %f / %f\n", i, + static_cast(h_ref[i]), + static_cast(h_sparse_ref[i]), + static_cast(h_C[i])); + } + + for (uint32_t i = 0; i < h_ref.size(); ++i) { + if (!Comparator::compare(h_C[i], h_ref[i], i, errors)) { + ++errors; + } + } + } + + // cleanup + std::cout << "cleanup" << std::endl; + cleanup(); + + if (errors != 0) { + std::cout << "Found " << std::dec << errors << " / " << sizeC << " errors!" << std::endl; + std::cout << "FAILED!" << std::endl; + return errors; + } + + std::cout << "PASSED!" 
<< std::endl; + + return 0; +} \ No newline at end of file diff --git a/tests/regression/sgemm_tcu_struct_sparse/tensor_generic.cpp b/tests/regression/sgemm_tcu_struct_sparse/tensor_generic.cpp new file mode 100644 index 000000000..63cdba487 --- /dev/null +++ b/tests/regression/sgemm_tcu_struct_sparse/tensor_generic.cpp @@ -0,0 +1,1076 @@ +#include +#include +#include +#include +#include +#include + +#define ENABLE_SPARSITY true +// Include random header only when sparsity is enabled +#ifdef ENABLE_SPARSITY +#include +#endif + +struct int4_t { + uint8_t data; +}; + +using float32_t = float; + +// ============================================================================ +// Configuration Macros +// ============================================================================ +#ifndef NUM_THREADS +#define NUM_THREADS 8 // Should be 32 for paper accuracy +#endif + +#ifndef XLENB +#define XLENB 4 +#endif + +#ifndef ITYPE +#define ITYPE int16_t +#endif + +#ifndef OTYPE +#define OTYPE int32_t +#endif + +#ifndef DPLEN +#define DPLEN 0 +#endif + +// ============================================================================ +// Debug Output Macros +// ============================================================================ +#ifdef NDEBUG +#define DBG_PRINT(fmt, ...) +#else +#define DBG_PRINT(fmt, ...) 
\ + do { \ + fprintf(stderr, fmt, __VA_ARGS__); \ + } while (0) +#endif + +#ifdef NDEBUG +class NullStream { +public: + template NullStream &operator<<(const T &) { return *this; } + NullStream &operator<<(std::ostream &(*)(std::ostream &)) { return *this; } + void flush() {} + static NullStream &instance() { + static NullStream null_stream; + return null_stream; + } +}; +#define dbg_out NullStream::instance() +#else +#define dbg_out std::cout +#endif + +template +struct DebugPrint; + +// ============================================================================ +// WMMA Configuration Template +// ============================================================================ +template +struct wmma_config_t { +private: + static constexpr uint32_t clog2(uint32_t x) { + return (x < 2) ? 0 : (1 + clog2(x / 2)); + } + static constexpr uint32_t tile_cap = NT * NR; + static constexpr uint32_t lg_tile_cap = clog2(tile_cap); + static constexpr uint32_t tile_en = lg_tile_cap / 2; + static constexpr uint32_t tile_em = lg_tile_cap - tile_en; + + static constexpr uint32_t block_cap = NT; + static constexpr uint32_t lg_block_cap = clog2(block_cap); + static constexpr uint32_t block_en = lg_block_cap / 2; + static constexpr uint32_t block_em = lg_block_cap - block_en; + +public: + static_assert(XB >= 0 && XB <= 8, "invalid XB value!"); + + static constexpr uint32_t i_ratio = XB / sizeof(It); + static constexpr uint32_t o_ratio = XB / sizeof(Ot); + static_assert(i_ratio * sizeof(It) == XB, "XB must be multiple of sizeof(It)"); + static_assert(o_ratio * sizeof(Ot) == XB, "XB must be multiple of sizeof(Ot)"); + + static constexpr uint32_t NumThreads = NT; + static constexpr uint32_t NumRegs = NR; + + static constexpr uint32_t xtileM = 1u << tile_em; + static constexpr uint32_t xtileN = 1u << tile_en; + static constexpr uint32_t xtileK = tile_cap / ((xtileM > xtileN) ? 
xtileM : xtileN); + + static constexpr uint32_t tcM = 1u << block_em; + static constexpr uint32_t tcN = 1u << block_en; + static constexpr uint32_t tcK = (DP != 0) ? DP : (block_cap / ((tcM > tcN) ? tcM : tcN)); + + static constexpr uint32_t m_steps = xtileM / tcM; + static constexpr uint32_t n_steps = xtileN / tcN; + static constexpr uint32_t k_steps = xtileK / tcK; + + static constexpr uint32_t a_block_size = tcM * tcK; + static constexpr uint32_t a_sub_blocks = block_cap / a_block_size; + static constexpr uint32_t a_sub_steps = m_steps / a_sub_blocks; + +#ifdef ENABLE_SPARSITY + // For 2:4 sparsity, B needs to provide both potential values + static constexpr uint32_t SPARSITY_RATIO = 2; + static constexpr uint32_t b_block_size = tcK * tcN * SPARSITY_RATIO; + static constexpr uint32_t b_sub_blocks = block_cap / b_block_size; + static constexpr uint32_t b_sub_steps = n_steps / b_sub_blocks; +#else + // Dense mode: standard B block configuration + static constexpr uint32_t b_block_size = tcK * tcN; + static constexpr uint32_t b_sub_blocks = block_cap / b_block_size; + static constexpr uint32_t b_sub_steps = n_steps / b_sub_blocks; +#endif + + static constexpr uint32_t NRA = (xtileM * xtileK) / NT; + static constexpr uint32_t NRB = (xtileN * xtileK) / NT; + static constexpr uint32_t NRC = (xtileM * xtileN) / NT; + + static constexpr uint32_t tileM = xtileM; + static constexpr uint32_t tileN = xtileN; + static constexpr uint32_t tileK = xtileK * i_ratio; + + static_assert(a_sub_steps != 0, "tcK is too small for tile A"); + static_assert(b_sub_steps != 0, "tcK is too small for tile B"); + + static_assert((xtileM * xtileK <= tile_cap), "xtileM * xtileK <= tile_cap"); + static_assert((xtileN * xtileK <= tile_cap), "xtileN * xtileK <= tile_cap"); + static_assert((xtileM * xtileN <= tile_cap), "xtileM * xtileN <= tile_cap"); + + static_assert((tcM * tcK <= block_cap), "tcM * tcK <= block_cap"); + static_assert((tcN * tcK <= block_cap), "tcN * tcK <= block_cap"); + 
static_assert((tcM * tcN <= block_cap), "tcM * tcN <= block_cap"); + + static_assert((xtileM % tcM) == 0, "M,m divisibility"); + static_assert((xtileN % tcN) == 0, "N,n divisibility"); + static_assert((xtileK % tcK) == 0, "K,k divisibility"); + + using vector_t = std::conditional_t<(XB == 1), uint8_t, + std::conditional_t<(XB == 2), uint16_t, + std::conditional_t<(XB == 4), uint32_t, uint64_t>>>; + using input_t = It; + using output_t = Ot; +}; + +// ============================================================================ +// Utility Types +// ============================================================================ +template +struct raw_unsigned { + static_assert( + sizeof(T) == 1 || sizeof(T) == 2 || + sizeof(T) == 4 || sizeof(T) == 8, + "raw_unsigned_t only supports types of size 1, 2, 4 or 8 bytes" + ); + + using type = std::conditional_t< + sizeof(T) == 1, uint8_t, + std::conditional_t< + sizeof(T) == 2, uint16_t, + std::conditional_t< + sizeof(T) == 4, uint32_t, + uint64_t + > + > + >; +}; + +template +using raw_unsigned_t = typename raw_unsigned::type; + +// ============================================================================ +// Pack Row Function +// ============================================================================ +template +D pack_row(const S *base, uint32_t ldm) { + static_assert(sizeof(D) % sizeof(S) == 0, "D must be a multiple of S"); + constexpr uint32_t count = sizeof(D) / sizeof(S); + using US = raw_unsigned_t; + D packed(0); + auto src = base; + for (uint32_t i = 0; i < count; ++i) { + US bits; + bits = *reinterpret_cast(src); + D elem = static_cast(bits); + packed |= (elem << (i * (8u * sizeof(S)))); + src += ldm; + } + return packed; +} + +// ============================================================================ +// Vector Register Type +// ============================================================================ +template +struct vector_t { +private: + std::array data_; + +public: + vector_t() = default; + + 
vector_t(T value) { + data_.fill(value); + } + + T* data() { + return data_.data(); + } + + const T* data() const { + return data_.data(); + } + + T& operator[](size_t idx) { + assert(idx < N); + return data_[idx]; + } + + const T& operator[](size_t idx) const { + assert(idx < N); + return data_[idx]; + } + + friend std::ostream &operator<<(std::ostream &os, const vector_t &v) { + os << std::hex << "{"; + for (size_t i = 0; i < N; ++i) { + if (i != 0) { + os << ", "; + } + os << "0x" << +v.data_[i]; + } + os << "}" << std::dec; + return os; + } +}; + +// ============================================================================ +// 2D Array Type +// ============================================================================ +template +struct array2d_t { +private: + std::array data_; + +public: + T* data() { + return data_.data(); + } + + const T* data() const { + return data_.data(); + } + + T &operator()(int row, int col) { + assert(row >= 0 && row < R); + assert(col >= 0 && col < C); + return data_[row * C + col]; + } + + const T &operator()(int row, int col) const { + assert(row >= 0 && row < R); + assert(col >= 0 && col < C); + return data_[row * C + col]; + } + + friend std::ostream &operator<<(std::ostream &os, const array2d_t &v) { + os << "{"; + for (size_t j = 0; j < R; ++j) { + if (j != 0) { + os << ", "; + } + os << "{"; + for (size_t i = 0; i < C; ++i) { + if (i != 0) { + os << ", "; + } + os << +v(j,i); + } + os << "}"; + } + os << "}"; + return os; + } +}; + +// ============================================================================ +// WMMA Implementation (Dense or Sparse based on ENABLE_SPARSITY) +// ============================================================================ +template +class WMMA { +private: + // Configuration constants + static constexpr uint32_t tileM = Config::tileM; + static constexpr uint32_t tileN = Config::tileN; + static constexpr uint32_t tileK = Config::tileK; + + static constexpr uint32_t tcM = Config::tcM; + 
static constexpr uint32_t tcN = Config::tcN; + static constexpr uint32_t tcK = Config::tcK; + + static constexpr uint32_t NT = Config::NumThreads; + static constexpr uint32_t NRA = Config::NRA; + static constexpr uint32_t NRB = Config::NRB; + static constexpr uint32_t NRC = Config::NRC; + + static constexpr uint32_t m_steps = Config::m_steps; + static constexpr uint32_t n_steps = Config::n_steps; + static constexpr uint32_t k_steps = Config::k_steps; + + static constexpr uint32_t a_block_size = Config::a_block_size; + static constexpr uint32_t a_sub_blocks = Config::a_sub_blocks; + static constexpr uint32_t a_sub_steps = Config::a_sub_steps; + + static constexpr uint32_t b_block_size = Config::b_block_size; + static constexpr uint32_t b_sub_blocks = Config::b_sub_blocks; + static constexpr uint32_t b_sub_steps = Config::b_sub_steps; + + static constexpr uint32_t i_ratio = Config::i_ratio; + static constexpr uint32_t o_ratio = Config::o_ratio; + +#ifdef ENABLE_SPARSITY + // Sparsity-specific constants + static constexpr uint32_t SPARSITY_N = 2; // 2 non-zero elements + static constexpr uint32_t SPARSITY_M = 4; // out of 4 elements (2:4 sparsity) + static constexpr uint32_t METADATA_LANES = Config::NumThreads / 4 / sizeof(typename Config::input_t); // Lanes 0,1 hold metadata for NT8, int8_t, 8Registers; for int16_t, NT=8, 4Registers, only lane 0 holds metadata + static constexpr uint32_t COMPRESSION_RATE = SPARSITY_M / SPARSITY_N; // 2x compression +#endif + + using Xt = typename Config::vector_t; + using It = typename Config::input_t; + using Ot = typename Config::output_t; + + using Vreg = vector_t; + + using FragA = array2d_t; + using FragB = array2d_t; + using FragC = array2d_t; + using FragD = array2d_t; + + // Matrix fragments + FragA fragA_; + FragB fragB_; + FragC fragC_; + FragD fragD_; + +#ifdef ENABLE_SPARSITY + // Sparsity-specific data structures + using FragA_meta = array2d_t; + + FragA fragA_compressed_; // Compressed matrix A (50% storage) + 
FragA_meta fragA_meta_; // Metadata: 1 = non-zero, 0 = pruned + static constexpr uint32_t META_ARRAY_SIZE = (tileM * tileK) / 32; //Total meta: tileM*tileK, each RISC-V register holds 32 bits + vector_t packed_bit_meta_; // Packed bitmap metadata. NT = 8, int8, 8REGS, MetaThreads = 2; int16, 4Regs, MetaThreads = 1 +#endif + + FragD fragRef_; + + uint32_t loop_iteration_count_; // Counter for total loop iterations + + // ======================================================================== + // Sparsity Helper Functions (only compiled when ENABLE_SPARSITY is defined) + // ======================================================================== +#ifdef ENABLE_SPARSITY + // Apply 2:4 structured pruning pattern + void apply_2_4_pruning(std::mt19937 &gen) { + std::vector masks = {1, 1, 0, 0}; // 2 ones, 2 zeros + + for (uint32_t r = 0; r < tileM; ++r) { + for (uint32_t c = 0; c < tileK / SPARSITY_M; ++c) { + // Shuffle the mask for this group of 4 elements + std::shuffle(masks.begin(), masks.end(), gen); + + // Apply mask to each element in the group + for (uint32_t c_4 = 0; c_4 < SPARSITY_M; ++c_4) { + uint32_t col = c * SPARSITY_M + c_4; + if (masks[c_4] == 0) { + fragA_(r, col) = 0; + fragA_meta_(r, col) = 0; + } else { + fragA_meta_(r, col) = 1; + } + } + } + } + } + + // Compress matrix A by removing zeros + void compress_matrix_A() { + // Initialize compressed matrix to zero + for (uint32_t r = 0; r < tileM; ++r) { + for (uint32_t c = 0; c < tileK; ++c) { + fragA_compressed_(r, c) = 0; + } + } + + // Pack non-zero elements into compressed format + uint32_t comp_cnt = 0; + for (uint32_t r = 0; r < tileM; ++r) { + for (uint32_t c = 0; c < tileK; ++c) { + if (fragA_meta_(r, c) == 1) { + uint32_t comp_r = comp_cnt / (tileK / COMPRESSION_RATE); + uint32_t comp_c = comp_cnt % (tileK / COMPRESSION_RATE); + fragA_compressed_(comp_r, comp_c) = fragA_(r, c); + comp_cnt++; + } + } + } + } + + // Pack metadata into compact bitmap format + void pack_metadata_bitmap() { + 
constexpr uint32_t ELEMENTS_PER_ROW = tcK * i_ratio * COMPRESSION_RATE; + constexpr uint32_t ROWS_PER_CHUNK = tcM / COMPRESSION_RATE * sizeof(It); + + constexpr uint32_t k_steps_compressed = k_steps / COMPRESSION_RATE; + constexpr uint32_t num_chunks = COMPRESSION_RATE / sizeof(It); + + for (uint32_t m = 0; m < m_steps; ++m) { + for (uint32_t k = 0; k < k_steps / COMPRESSION_RATE; ++k) { + for (uint32_t chunk = 0; chunk < COMPRESSION_RATE / sizeof(It); ++chunk) { + uint32_t tmp_bit = 0; + + // Pack metadata for this chunk + for (uint32_t r_i = 0; r_i < ROWS_PER_CHUNK; ++r_i) { + for (uint32_t c_i = 0; c_i < ELEMENTS_PER_ROW; ++c_i) { + uint32_t row = r_i + chunk * ROWS_PER_CHUNK + m * tcM; + uint32_t col = c_i + k * ELEMENTS_PER_ROW; + + if (fragA_meta_(row, col) == 1) { + uint32_t bit_pos = 31 - (c_i + r_i * ELEMENTS_PER_ROW); + tmp_bit |= (1ULL << bit_pos); + } + } + } + uint32_t idx; + // + if(sizeof(It) == 1 || sizeof(It) == 2){ + idx = chunk + k * num_chunks + m * k_steps_compressed * num_chunks; + }else{ + static_assert(sizeof(It) == 1 || sizeof(It) == 2, "Only int8_t and int16_t are supported for sparsity"); + } + packed_bit_meta_[idx] = tmp_bit; + } + } + } + } + + // Extract bitmap for a specific row + uint16_t extract_row_metadata_int8_t(const Vreg &va_meta, uint32_t row_idx) const { + static_assert(sizeof(It) == 1, "int8_t extractor requires sizeof(It)==1"); + uint32_t meta_reg_idx = row_idx / COMPRESSION_RATE; + bool is_upper_half = (row_idx % COMPRESSION_RATE) == 0; + return is_upper_half ? 
+            // NOTE(review): cast target reconstructed — the extraction stripped all
+            // single-line <...> template-argument spans. int8 metadata is 16 bits per
+            // row (8 elements x 2 bits), so the high/low half of the 32-bit word is
+            // returned; confirm against the declaration of extract_row_metadata_int8_t.
+            static_cast<uint16_t>(va_meta[meta_reg_idx] >> 16) :
+            static_cast<uint16_t>(va_meta[meta_reg_idx]);
+    }
+
+    // Extract the 8-bit sparsity bitmap for one A row (int16 element type).
+    // Four row bitmaps are packed MSB-first into each 32-bit metadata word.
+    uint8_t extract_row_metadata_int16_t(const Vreg &va_meta, uint32_t row_idx) const {
+        constexpr uint32_t ELEMENTS_PER_ROW = COMPRESSION_RATE * (tcK * i_ratio); // = 8
+        constexpr uint32_t ROWS_PER_CHUNK = 32 / ELEMENTS_PER_ROW;                // = 4
+        constexpr uint32_t ROW_MASK = 0xFF; // masks 8 bits
+
+        uint32_t meta_reg_idx = row_idx / ROWS_PER_CHUNK; // which 32-bit word
+        uint32_t which_part = row_idx % ROWS_PER_CHUNK;   // which byte within it
+        uint32_t shift = 32 - (which_part + 1) * ELEMENTS_PER_ROW;
+
+        uint32_t word = (va_meta[meta_reg_idx]);
+        return static_cast<uint8_t>((word >> shift) & ROW_MASK);
+    }
+
+    // Gather B column elements based on A's sparsity pattern.
+    // Fills b_collected with tcK*i_ratio elements selected from the two B
+    // column halves: set upper bits of a_row_meta pick from b_col_0, set
+    // lower bits pick from b_col_1.
+    void gather_sparse_B_column(
+        It *b_collected,
+        const Xt *b_col_0,
+        const Xt *b_col_1,
+        uint16_t a_row_meta) const {
+        // for ITYPE=int16_t, a_row_meta is promoted from the uint8_t bitmap
+        constexpr uint32_t TOTAL_ELEMENTS = tcK * i_ratio;
+        uint32_t collect_idx = 0;
+
+        static_assert(sizeof(It) == 1 || sizeof(It) == 2, "Only int8_t and int16_t are supported for sparsity");
+        uint32_t b_Mask = (uint32_t{1} << (8 * sizeof(It))) - 1; // 0xFF for int8, 0xFFFF for int16
+
+        // Gather from first half based on upper bits of metadata
+        for (uint32_t bit_idx = 0; bit_idx < TOTAL_ELEMENTS; ++bit_idx) {
+            uint32_t bit_pos = TOTAL_ELEMENTS * SPARSITY_N - bit_idx - 1;
+            if ((a_row_meta & (1 << bit_pos)) != 0) {
+                uint32_t element_idx = bit_idx / i_ratio;
+                uint32_t byte_pos = (bit_idx % i_ratio) * 8 * sizeof(It);
+                b_collected[collect_idx++] =
+                    static_cast<It>((b_col_0[element_idx] >> byte_pos) & b_Mask);
+            }
+        }
+
+        // Gather from second half based on lower bits of metadata
+        for (uint32_t bit_idx = 0; bit_idx < TOTAL_ELEMENTS; ++bit_idx) {
+            if (collect_idx >= TOTAL_ELEMENTS) break;
+
+            uint32_t bit_pos = TOTAL_ELEMENTS - bit_idx - 1;
+            if ((a_row_meta & (1 << bit_pos)) != 0) {
+                uint32_t element_idx = bit_idx / i_ratio;
+                uint32_t byte_pos = (bit_idx % i_ratio) * 8 * sizeof(It);
+                b_collected[collect_idx++] =
+                    static_cast<It>((b_col_1[element_idx] >> byte_pos) & b_Mask);
+            }
+        }
+    }
+#endif // ENABLE_SPARSITY
+
+    // ========================================================================
+    // Load/Store Operations (different implementations for dense/sparse)
+    // ========================================================================
+
+#ifdef ENABLE_SPARSITY
+    // Sparse version of load_A: compressed A data fills the first
+    // NRA/COMPRESSION_RATE registers; packed metadata fills the remainder
+    // (metadata lanes only, other lanes zeroed).
+    void load_A(vector_t &vR, uint32_t lane, uint32_t ldm,
+                const It *mdata, const vector_t &A_meta) {
+        uint32_t block_idx = lane / a_block_size;
+        uint32_t lane_in_block = lane % a_block_size;
+        uint32_t elem_row = lane_in_block / tcK;
+        uint32_t elem_col = lane_in_block % tcK;
+
+        // Load compressed data into first half of registers
+        for (uint32_t r = 0; r < NRA / COMPRESSION_RATE; ++r) {
+            uint32_t block_m = (r / (k_steps / COMPRESSION_RATE)) * a_sub_blocks + block_idx;
+            uint32_t block_k = r % (k_steps / COMPRESSION_RATE);
+            uint32_t row = block_m * tcM + elem_row;
+            uint32_t col = block_k * tcK + elem_col;
+            auto base = mdata + row * ldm + col * i_ratio;
+
+            assert(reinterpret_cast<uintptr_t>(base) % alignof(Xt) == 0 &&
+                   "Base pointer must be aligned");
+            vR[r][lane] = *reinterpret_cast<const Xt*>(base);
+        }
+
+        // Load metadata into second half (only for metadata lanes)
+        if (lane < METADATA_LANES) {
+            for (uint32_t r = NRA / COMPRESSION_RATE; r < NRA; ++r) {
+                uint32_t meta_idx = (COMPRESSION_RATE * (r - NRA / COMPRESSION_RATE) + lane) / sizeof(It);
+                vR[r][lane] = A_meta.data()[meta_idx];
+            }
+        } else {
+            for (uint32_t r = NRA / COMPRESSION_RATE; r < NRA; ++r) {
+                vR[r][lane] = 0;
+            }
+        }
+    }
+#else
+    // Dense version of load_A: distributes the tileM x tileK A tile across
+    // lanes and the NRA vector registers.
+    void load_A(vector_t &vR, uint32_t lane, uint32_t ldm, const It *mdata) {
+        uint32_t block_idx = lane / a_block_size;
+        uint32_t lane_in_block = lane % a_block_size;
+        uint32_t elem_row = lane_in_block / tcK;
+        uint32_t elem_col = lane_in_block % tcK;
+
+        for (uint32_t r = 0; r < NRA; ++r) {
+            uint32_t block_m = (r / k_steps) * a_sub_blocks + block_idx;
+            uint32_t block_k = r % k_steps;
+            uint32_t row = block_m * tcM + elem_row;
+            uint32_t col = block_k * tcK + elem_col;
+            auto base = mdata + row * ldm + col * i_ratio;
+
+            assert(reinterpret_cast<uintptr_t>(base) % alignof(Xt) == 0 &&
+                   "Base pointer must be aligned to sizeof(Xt)");
+            vR[r][lane] = *reinterpret_cast<const Xt*>(base);
+        }
+    }
+#endif
+
+#ifdef ENABLE_SPARSITY
+    // Sparse version of load_B (loads 2x data for sparse B access)
+    void load_B(vector_t &vR, uint32_t lane, uint32_t ldm, const It *mdata) {
+        uint32_t block_idx = lane / b_block_size;
+        uint32_t lane_in_block = lane % b_block_size;
+        uint32_t elem_col = lane_in_block / (tcK * COMPRESSION_RATE);
+        uint32_t elem_row = lane_in_block % (tcK * COMPRESSION_RATE);
+
+        for (uint32_t r = 0; r < NRB; ++r) {
+            uint32_t block_k = r / b_sub_steps;
+            uint32_t block_n = (r % b_sub_steps) * b_sub_blocks + block_idx;
+            uint32_t row = block_k * tcK * COMPRESSION_RATE + elem_row;
+            uint32_t col = block_n * tcN + elem_col;
+            auto base = mdata + row * ldm * i_ratio + col;
+
+            if constexpr (sizeof(Xt) == sizeof(It)) {
+                vR[r][lane] = *reinterpret_cast<const Xt*>(base);
+            } else {
+                vR[r][lane] = pack_row(base, ldm);
+            }
+        }
+    }
+#else
+    // Dense version of load_B: distributes the tileK x tileN B tile across
+    // lanes and the NRB vector registers.
+    void load_B(vector_t &vR, uint32_t lane, uint32_t ldm, const It *mdata) {
+        uint32_t block_idx = lane / b_block_size;
+        uint32_t lane_in_block = lane % b_block_size;
+        uint32_t elem_col = lane_in_block / tcK;
+        uint32_t elem_row = lane_in_block % tcK;
+
+        for (uint32_t r = 0; r < NRB; ++r) {
+            uint32_t block_k = r / b_sub_steps;
+            uint32_t block_n = (r % b_sub_steps) * b_sub_blocks + block_idx;
+            uint32_t row = block_k * tcK + elem_row;
+            uint32_t col = block_n * tcN + elem_col;
+            auto base = mdata + row * ldm * i_ratio + col;
+
+            if constexpr (sizeof(Xt) == sizeof(It)) {
+                vR[r][lane] = *reinterpret_cast<const Xt*>(base);
+            } else {
+                vR[r][lane] = pack_row(base, ldm);
+            }
+        }
+    }
+#endif
+
+    // Load the accumulator tile C into NRC vector registers (shared between
+    // dense and sparse builds).
+    void load_C(vector_t &vR, uint32_t lane, uint32_t ldm, const Ot *mdata) {
+        uint32_t elem_row = lane / tcN;
+        uint32_t elem_col = lane % tcN;
+
+        for (uint32_t r = 0; r < NRC; ++r) {
+            uint32_t block_m = r / n_steps;
+            uint32_t block_n = r % n_steps;
+            uint32_t row = block_m * tcM + elem_row;
+            uint32_t col = block_n * tcN + elem_col;
+            auto base = mdata + row * ldm + col;
+
+            if constexpr (sizeof(Xt) == sizeof(Ot)) {
+                vR[r][lane] = *reinterpret_cast<const Xt*>(base);
+            } else {
+                // Widen: place the Ot bits into the low part of an Xt register word.
+                Xt tmp(0);
+                *reinterpret_cast<Ot*>(&tmp) = *base;
+                vR[r][lane] = tmp;
+            }
+        }
+    }
+
+    // Store the result registers back into the D tile (inverse of load_C).
+    void store_D(Ot *mdata, uint32_t lane, uint32_t ldm, const vector_t &vR) {
+        uint32_t elem_row = lane / tcN;
+        uint32_t elem_col = lane % tcN;
+
+        for (uint32_t r = 0; r < NRC; ++r) {
+            uint32_t block_m = r / n_steps;
+            uint32_t block_n = r % n_steps;
+            uint32_t row = block_m * tcM + elem_row;
+            uint32_t col = block_n * tcN + elem_col;
+            auto base = mdata + row * ldm + col;
+
+            if constexpr (sizeof(Xt) == sizeof(Ot)) {
+                *reinterpret_cast<Xt*>(base) = vR[r][lane];
+            } else {
+                // Narrow: emit only the Ot-sized low part of the register word.
+                Xt tmp(vR[r][lane]);
+                *base = *reinterpret_cast<Ot*>(&tmp);
+            }
+        }
+    }
+
+    // ========================================================================
+    // Core Computation Operations
+    // ========================================================================
+
+    // Fused Element-wise Dot Product: accumulates tcK*i_ratio products of
+    // It-typed elements into an Ot accumulator carried in Xt register bits.
+    Xt FEDP(const Xt *a_row, const Xt *b_col, Xt c_val) const {
+        Ot acc(*reinterpret_cast<const Ot*>(&c_val));
+        auto a = reinterpret_cast<const It*>(a_row);
+        auto b = reinterpret_cast<const It*>(b_col);
+        for (uint32_t z = 0; z < tcK * i_ratio; ++z) {
+            auto a_val = static_cast<Ot>(a[z]);
+            auto b_val = static_cast<Ot>(b[z]);
+            acc = a_val * b_val + acc;
+        }
+        Xt ret(0);
+        *reinterpret_cast<Ot*>(&ret) = acc;
+        return ret;
+    }
+
+#ifdef ENABLE_SPARSITY
+    // Sparse Matrix Multiply-Accumulate micro-operation: gathers B elements
+    // per A-row metadata before each dot product.
+    Vreg MMA(uint32_t m, uint32_t n, const Vreg &va, const Vreg &va_meta,
+             const Vreg &vb, const Vreg &vc) {
+        uint32_t a_off = (m % a_sub_blocks) * a_block_size;
+        uint32_t b_off = (n % b_sub_blocks) * b_block_size;
+
+        Vreg vd;
+        It b_col_collected[tcK * i_ratio];
+
+        for (uint32_t i = 0; i < tcM; ++i) {
+            for (uint32_t j = 0; j < tcN; ++j) {
+                auto a_row = &va[a_off + i * tcK];
+                auto b_col_0 = &vb[b_off + j * tcK * COMPRESSION_RATE];
+                auto b_col_1 = &vb[b_off + j * tcK * COMPRESSION_RATE + tcK];
+                auto c = vc[i * tcN + j];
+
+                // Extract metadata for this row (element width selects layout);
+                // zero-init so a_row_meta is defined even for unsupported widths
+                // (gather_sparse_B_column's static_assert rejects those anyway).
+                uint32_t a_row_meta = 0;
+                if constexpr (sizeof(It) == 1) {
+                    a_row_meta = extract_row_metadata_int8_t(va_meta, i);
+                } else if constexpr (sizeof(It) == 2) {
+                    a_row_meta = extract_row_metadata_int16_t(va_meta, i);
+                }
+                // Gather sparse B elements based on A's metadata
+                gather_sparse_B_column(b_col_collected, b_col_0, b_col_1, a_row_meta);
+
+                // Compute dot product
+                auto d = FEDP(a_row, reinterpret_cast<const Xt*>(b_col_collected), c);
+                vd[i * tcN + j] = d;
+            }
+        }
+
+        return vd;
+    }
+#else
+    // Dense Matrix Multiply-Accumulate micro-operation
+    Vreg MMA(uint32_t m, uint32_t n, const Vreg &va, const Vreg &vb, const Vreg &vc) {
+        uint32_t a_off = (m % a_sub_blocks) * a_block_size;
+        uint32_t b_off = (n % b_sub_blocks) * b_block_size;
+
+        Vreg vd;
+        for (uint32_t i = 0; i < tcM; ++i) {
+            for (uint32_t j = 0; j < tcN; ++j) {
+                auto a_row = &va[a_off + i * tcK];
+                auto b_col = &vb[b_off + j * tcK];
+                auto c = vc[i * tcN + j];
+                auto d = FEDP(a_row, b_col, c);
+                vd[i * tcN + j] = d;
+            }
+        }
+
+        return vd;
+    }
+#endif
+
+#ifdef ENABLE_SPARSITY
+    // Sparse matrix multiply-add operation: D = A_compressed * B + C, with
+    // A's 2:4 metadata steering the B gather. Note k only iterates over the
+    // compressed K extent (k_steps / COMPRESSION_RATE).
+    FragD mmadd(const FragA &A, const vector_t &A_meta,
+                const FragB &B, const FragC &C) {
+        FragD D;
+        vector_t vA;
+        vector_t vB;
+        vector_t vC, vD;
+
+        dbg_out << "A=" << A << "\n";
+        dbg_out << "B=" << B << "\n";
+        dbg_out << "C=" << C << "\n";
+
+        // Load fragments into vector registers
+        for (uint32_t lane = 0; lane < NT; ++lane) {
+            load_A(vA, lane, tileK, A.data(), A_meta);
+        }
+        for (uint32_t lane = 0; lane < NT; ++lane) {
+            load_B(vB, lane, tileN, B.data());
+        }
+        for (uint32_t lane = 0; lane < NT; ++lane) {
+            load_C(vC, lane, tileN, C.data());
+        }
+
+        // Execute micro-operations
+        for (uint32_t k = 0; k < k_steps / COMPRESSION_RATE; ++k) {
+            for (uint32_t m = 0; m < m_steps; ++m) {
+                for (uint32_t n = 0; n < n_steps; ++n) {
+                    loop_iteration_count_++; // Count loop iterations
+                    uint32_t idxA = (m / a_sub_blocks) * (k_steps / COMPRESSION_RATE) + k;
+                    uint32_t idxA_meta = idxA + NRA / COMPRESSION_RATE;
+                    uint32_t idxB = (k * n_steps + n) / b_sub_blocks;
+                    uint32_t idxC = m * n_steps + n;
+
+                    auto &va = vA[idxA];
+                    auto &va_meta = vA[idxA_meta];
+                    auto &vb = vB[idxB];
+                    // First k-step accumulates onto C; later steps chain onto D.
+                    auto &vc = (k != 0) ? vD[idxC] : vC[idxC];
+
+                    auto vd = MMA(m, n, va, va_meta, vb, vc);
+                    vD[idxC] = vd;
+                }
+            }
+        }
+
+        // Store results back to fragment
+        for (uint32_t lane = 0; lane < NT; ++lane) {
+            store_D(D.data(), lane, tileN, vD);
+        }
+
+        dbg_out << "D=" << D << "\n";
+        return D;
+    }
+#else
+    // Dense matrix multiply-add operation: D = A * B + C.
+    FragD mmadd(const FragA &A, const FragB &B, const FragC &C) {
+        FragD D;
+        vector_t vA;
+        vector_t vB;
+        vector_t vC, vD;
+
+        dbg_out << "A=" << A << "\n";
+        dbg_out << "B=" << B << "\n";
+        dbg_out << "C=" << C << "\n";
+
+        // per-lane load
+        for (uint32_t lane = 0; lane < NT; ++lane) {
+            load_A(vA, lane, tileK, A.data());
+        }
+        for (uint32_t lane = 0; lane < NT; ++lane) {
+            load_B(vB, lane, tileN, B.data());
+        }
+        for (uint32_t lane = 0; lane < NT; ++lane) {
+            load_C(vC, lane, tileN, C.data());
+        }
+
+        for (uint32_t i = 0; i < NRA; ++i) {
+            dbg_out << "vA" << i << "=" << vA[i] << "\n";
+        }
+        for (uint32_t i = 0; i < NRB; ++i) {
+            dbg_out << "vB" << i << "=" << vB[i] << "\n";
+        }
+        for (uint32_t i = 0; i < NRC; ++i) {
+            dbg_out << "vC" << i << "=" << vC[i] << "\n";
+        }
+
+        // micro-ops
+        for (uint32_t k = 0; k < k_steps; ++k) {
+            for (uint32_t m = 0; m < m_steps; ++m) {
+                for (uint32_t n = 0; n < n_steps; ++n) {
+                    loop_iteration_count_++; // Count loop iterations
+                    uint32_t idxA = (m / a_sub_blocks) * k_steps + k;
+                    uint32_t idxB = (k * n_steps + n) / b_sub_blocks;
+                    uint32_t idxC = m * n_steps + n;
+
+                    auto &va = vA[idxA];
+                    auto &vb = vB[idxB];
+                    // First k-step accumulates onto C; later steps chain onto D.
+                    auto &vc = (k != 0) ? vD[idxC] : vC[idxC];
+
+                    auto vd = MMA(m, n, va, vb, vc);
+                    vD[idxC] = vd;
+                }
+            }
+        }
+
+        dbg_out.flush();
+
+        for (uint32_t i = 0; i < NRC; ++i) {
+            dbg_out << "vD" << i << "=" << vD[i] << "\n";
+        }
+
+        // per-lane store
+        for (uint32_t lane = 0; lane < NT; ++lane) {
+            store_D(D.data(), lane, tileN, vD);
+        }
+
+        dbg_out << "D=" << D << "\n";
+        return D;
+    }
+#endif
+
+public:
+    // ========================================================================
+    // Public Interface
+    // ========================================================================
+
+    // Fill A/B with sequential values, zero C, optionally prune/compress A
+    // for sparsity, and compute the scalar reference result.
+    void init() {
+        int x = 0;
+
+        // Initialize matrix A with sequential values
+        for (uint32_t r = 0; r < tileM; ++r) {
+            for (uint32_t c = 0; c < tileK; ++c) {
+                fragA_(r, c) = x++;
+            }
+        }
+
+#ifdef ENABLE_SPARSITY
+        // Apply 2:4 structured sparsity
+        std::random_device rd;
+        std::mt19937 gen(rd());
+        apply_2_4_pruning(gen);
+
+        // Compress sparse matrix A
+        compress_matrix_A();
+
+        // Pack metadata into bitmap format
+        pack_metadata_bitmap();
+#endif
+
+        // Initialize matrix B with sequential values
+        for (uint32_t r = 0; r < tileK; ++r) {
+            for (uint32_t c = 0; c < tileN; ++c) {
+                fragB_(r, c) = x++;
+            }
+        }
+
+        // Initialize matrix C to zero
+        for (uint32_t r = 0; r < tileM; ++r) {
+            for (uint32_t c = 0; c < tileN; ++c) {
+                fragC_(r, c) = 0;
+            }
+        }
+
+        // Compute reference result
+        for (uint32_t row = 0; row < tileM; ++row) {
+            for (uint32_t col = 0; col < tileN; ++col) {
+                Ot sum(0);
+                for (uint32_t k = 0; k < tileK; ++k) {
+                    auto a = static_cast<Ot>(fragA_(row, k));
+                    auto b = static_cast<Ot>(fragB_(k, col));
+                    sum = a * b + sum;
+                }
+                fragRef_(row, col) = sum + fragC_(row, col);
+            }
+        }
+    }
+
+    // Returns the max absolute element-wise error of D against the reference.
+    float verify() const {
+        if constexpr (std::is_integral_v<Ot>) {
+            int32_t err(0);
+            for (uint32_t row = 0; row < tileM; ++row) {
+                for (uint32_t col = 0; col < tileN; ++col) {
+                    auto curr = static_cast<int32_t>(fragD_(row, col));
+                    auto ref = static_cast<int32_t>(fragRef_(row, col));
+                    auto diff = std::abs(curr - ref);
+                    err = std::max(err, diff);
+                }
+            }
+            return static_cast<float>(err);
+        } else {
+            float err(0);
+            for (uint32_t row = 0; row < tileM; ++row) {
+                for (uint32_t col = 0; col < tileN; ++col) {
+                    auto curr = static_cast<float>(fragD_(row, col));
+                    auto ref = static_cast<float>(fragRef_(row, col));
+                    auto diff = std::fabs(curr - ref);
+                    err = std::max(err, diff);
+                }
+            }
+            return err;
+        }
+    }
+
+    uint32_t get_loop_count() const {
+        return loop_iteration_count_;
+    }
+
+    void run() {
+        loop_iteration_count_ = 0; // Initialize counter
+#ifdef ENABLE_SPARSITY
+        fragD_ = mmadd(fragA_compressed_, packed_bit_meta_, fragB_, fragC_);
+#else
+        fragD_ = mmadd(fragA_, fragB_, fragC_);
+#endif
+    }
+};
+
+// ============================================================================
+// Main Test Driver
+// ============================================================================
+using cfg = wmma_config_t<
+    NUM_THREADS,
+    8,
+    XLENB,
+    OTYPE,
+    ITYPE,
+    DPLEN>;
+
+int main() {
+    // NOTE(review): if WMMA is a class template, the stripped argument list
+    // (e.g. WMMA<cfg>) must be restored here — the extraction removed all
+    // single-line <...> spans; verify against the WMMA declaration.
+    WMMA wmma;
+
+#ifdef ENABLE_SPARSITY
+    std::cout << "=== Sparse Tensor Core Configuration (2:4 Structured Sparsity) ===\n";
+#else
+    std::cout << "=== Dense Tensor Core Configuration ===\n";
+#endif
+
+    std::cout
+        << "tileM = " << cfg::tileM << "\n"
+        << "tileN = " << cfg::tileN << "\n"
+        << "tileK = " << cfg::tileK << "\n"
+        << "tcM = " << cfg::tcM << "\n"
+        << "tcN = " << cfg::tcN << "\n"
+        << "tcK = " << cfg::tcK << "\n"
+        << "m_steps = " << cfg::m_steps << "\n"
+        << "n_steps = " << cfg::n_steps << "\n"
+        << "k_steps = " << cfg::k_steps << "\n"
+        << "a_block_size = " << cfg::a_block_size << "\n"
+        << "a_sub_blocks = " << cfg::a_sub_blocks << "\n"
+        << "a_sub_steps = " << cfg::a_sub_steps << "\n"
+        << "b_block_size = " << cfg::b_block_size << "\n"
+        << "b_sub_blocks = " << cfg::b_sub_blocks << "\n"
+        << "b_sub_steps = " << cfg::b_sub_steps << "\n"
+        << "NRA = " << cfg::NRA << "\n"
+        << "NRB = " << cfg::NRB << "\n"
+        << "NRC = " << cfg::NRC << "\n"
+        << "\n";
+
+    wmma.init();
+    wmma.run();
+
+    auto err = wmma.verify();
+    bool passed = (err < 1e-4f);
+
+    std::cout << "Total loop iterations: " << wmma.get_loop_count() << "\n"
+              << "Max abs error: " << err << "\n"
+              << (passed ? "PASSED!" : "FAILED!") << '\n';
+
+    return passed ? 0 : 1;
+}
+
+// ============================================================================
+// Build Instructions
+// ============================================================================
+// Dense mode (default):
+//   g++ -std=c++17 -O2 tensor_generic.cpp -o a.out
+//
+// Sparse mode (2:4 structured sparsity):
+//   g++ -std=c++17 -O2 -DENABLE_SPARSITY tensor_generic.cpp -o a.out
+//
+// Debug builds:
+//   g++ -std=c++17 -g tensor_generic.cpp -o a.out
+//   g++ -std=c++17 -g -DENABLE_SPARSITY tensor_generic.cpp -o a.out