From d73a042fc6f7f077b99d62f937ebfcce56269e95 Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Thu, 5 Feb 2026 15:39:04 -0800 Subject: [PATCH 01/22] Add sgemm_tcu_struct_sparse test and update TCU files --- hw/rtl/tcu/VX_tcu_pkg.sv | 2 +- hw/rtl/tcu/VX_tcu_uops.sv | 10 +- kernel/include/vx_tensor.h | 16 +- sim/common/tensor_cfg.h | 2 +- .../sgemm_tcu_struct_sparse/Makefile | 16 + .../sgemm_tcu_struct_sparse/common.h | 27 + .../sgemm_tcu_struct_sparse/kernel.cpp | 56 + .../sgemm_tcu_struct_sparse/main.cpp | 910 ++++++++++++++ .../sgemm_tcu_struct_sparse/sparse_test.py | 44 + .../tensor_generic.cpp | 1076 +++++++++++++++++ 10 files changed, 2149 insertions(+), 10 deletions(-) create mode 100644 tests/regression/sgemm_tcu_struct_sparse/Makefile create mode 100644 tests/regression/sgemm_tcu_struct_sparse/common.h create mode 100644 tests/regression/sgemm_tcu_struct_sparse/kernel.cpp create mode 100644 tests/regression/sgemm_tcu_struct_sparse/main.cpp create mode 100644 tests/regression/sgemm_tcu_struct_sparse/sparse_test.py create mode 100644 tests/regression/sgemm_tcu_struct_sparse/tensor_generic.cpp diff --git a/hw/rtl/tcu/VX_tcu_pkg.sv b/hw/rtl/tcu/VX_tcu_pkg.sv index ebb30820ea..1a172e2674 100644 --- a/hw/rtl/tcu/VX_tcu_pkg.sv +++ b/hw/rtl/tcu/VX_tcu_pkg.sv @@ -74,7 +74,7 @@ package VX_tcu_pkg; localparam TCU_A_SUB_BLOCKS = TCU_BLOCK_CAP / TCU_A_BLOCK_SIZE; // B micro-tiling - localparam TCU_B_BLOCK_SIZE = TCU_TC_K * TCU_TC_N; + localparam TCU_B_BLOCK_SIZE = (TCU_TC_K * TCU_TC_N)*2; // sparsity 2601223 localparam TCU_B_SUB_BLOCKS = TCU_BLOCK_CAP / TCU_B_BLOCK_SIZE; // Register counts diff --git a/hw/rtl/tcu/VX_tcu_uops.sv b/hw/rtl/tcu/VX_tcu_uops.sv index 568f838802..3b06ea49bb 100644 --- a/hw/rtl/tcu/VX_tcu_uops.sv +++ b/hw/rtl/tcu/VX_tcu_uops.sv @@ -62,10 +62,14 @@ module VX_tcu_uops import end // Register offsets - wire [CTR_W-1:0] rs1_offset = ((CTR_W'(m_index) >> LG_A_SB) << LG_K) | CTR_W'(k_index); + // wire [CTR_W-1:0] rs1_offset = ((CTR_W'(m_index) >> 
LG_A_SB) << LG_K) | CTR_W'(k_index); + // wire [CTR_W-1:0] rs2_offset = ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB; + // wire [CTR_W-1:0] rs3_offset = (CTR_W'(m_index) << LG_N) | CTR_W'(n_index); + + wire [CTR_W-1:0] rs1_offset = ((CTR_W'(m_index) >> LG_A_SB) << (LG_K/2)) | CTR_W'(k_index); wire [CTR_W-1:0] rs2_offset = ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB; wire [CTR_W-1:0] rs3_offset = (CTR_W'(m_index) << LG_N) | CTR_W'(n_index); - + // Register calculations wire [4:0] rs1 = TCU_RA + 5'(rs1_offset); wire [4:0] rs2 = TCU_RB + 5'(rs2_offset); @@ -115,7 +119,7 @@ module VX_tcu_uops import done <= (TCU_UOPS == 1); end else if (busy && next) begin counter <= counter + ((TCU_UOPS > 1) ? 1 : 0); - done <= (counter == CTR_W'(TCU_UOPS-2)); + done <= (counter == CTR_W'((TCU_UOPS/2)-2)); // sparsity 2601223 busy <= ~done; end end diff --git a/kernel/include/vx_tensor.h b/kernel/include/vx_tensor.h index efb3647ef2..9caae4d786 100644 --- a/kernel/include/vx_tensor.h +++ b/kernel/include/vx_tensor.h @@ -117,7 +117,6 @@ namespace detail { return *reinterpret_cast(&result_u); } }; - } template (ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); - dst.data[r] = *reinterpret_cast(ptr); + //dst.data[r] = *reinterpret_cast(ptr); + if (r < 4) { + dst.data[r] = *reinterpret_cast(ptr); + } else { + // Zero for r=4,5,6,7 + uint32_t zero = 0; + dst.data[r] = *reinterpret_cast(&zero); + } } }); } else if constexpr (Frag::Use == matrix_b) { // Load column-major matrix B uint32_t block_idx = (cfg::b_block_size == NT) ? 0 : (lane / cfg::b_block_size); uint32_t lane_in_blk = (cfg::b_block_size == NT) ? 
lane : (lane % cfg::b_block_size); - uint32_t block_col = (lane_in_blk / cfg::tcK) + (block_idx * cfg::tcN); - uint32_t block_row = (lane_in_blk % cfg::tcK) * i_ratio; + uint32_t block_col = (lane_in_blk / ((cfg::tcK)*2)) + (block_idx * cfg::tcN); + uint32_t block_row = (lane_in_blk % ((cfg::tcK)*2)) * i_ratio; uint32_t n_stride = cfg::b_sub_blocks * cfg::tcN; - uint32_t k_stride = cfg::tcK * i_ratio; + uint32_t k_stride = ((cfg::tcK)*2) * i_ratio; if constexpr (src_layout == col_major) { std::swap(block_row, block_col); } diff --git a/sim/common/tensor_cfg.h b/sim/common/tensor_cfg.h index e39c7fff2c..f46dbf4989 100644 --- a/sim/common/tensor_cfg.h +++ b/sim/common/tensor_cfg.h @@ -191,7 +191,7 @@ struct wmma_config_t { static constexpr uint32_t a_sub_blocks = block_cap / a_block_size; // number of A micro-tiles per register static constexpr uint32_t a_sub_steps = m_steps / a_sub_blocks; // number of A sub-steps per register - static constexpr uint32_t b_block_size = tcK * tcN; // size of B micro-tile + static constexpr uint32_t b_block_size = (tcK * tcN)*2; // size of B micro-tile static constexpr uint32_t b_sub_blocks = block_cap / b_block_size; // number of B micro-tiles per register static constexpr uint32_t b_sub_steps = n_steps / b_sub_blocks; // number of B sub-steps per register diff --git a/tests/regression/sgemm_tcu_struct_sparse/Makefile b/tests/regression/sgemm_tcu_struct_sparse/Makefile new file mode 100644 index 0000000000..e2c7b0ee04 --- /dev/null +++ b/tests/regression/sgemm_tcu_struct_sparse/Makefile @@ -0,0 +1,16 @@ +ROOT_DIR := $(realpath ../../..) 
+include $(ROOT_DIR)/config.mk + +PROJECT := sgemm_tcu_struct_sparse + +SRC_DIR := $(VORTEX_HOME)/tests/regression/$(PROJECT) + +SRCS := $(SRC_DIR)/main.cpp $(SW_COMMON_DIR)/rvfloats.cpp $(SW_COMMON_DIR)/softfloat_ext.cpp + +VX_SRCS := $(SRC_DIR)/kernel.cpp + +CXXFLAGS += -I$(THIRD_PARTY_DIR)/softfloat/source/include + +LDFLAGS += $(THIRD_PARTY_DIR)/softfloat/build/Linux-x86_64-GCC/softfloat.a + +include ../common.mk \ No newline at end of file diff --git a/tests/regression/sgemm_tcu_struct_sparse/common.h b/tests/regression/sgemm_tcu_struct_sparse/common.h new file mode 100644 index 0000000000..a762a4fb2e --- /dev/null +++ b/tests/regression/sgemm_tcu_struct_sparse/common.h @@ -0,0 +1,27 @@ +#ifndef _COMMON_H_ +#define _COMMON_H_ + +#include + +#ifndef NUM_THREADS +#define NUM_THREADS 4 +#endif + +#ifndef ITYPE +#define ITYPE fp16 +#endif + +#ifndef OTYPE +#define OTYPE fp32 +#endif + +typedef struct { + uint32_t grid_dim[2]; + uint32_t block_dim[2]; + uint32_t M, N, K; + uint64_t A_addr; + uint64_t B_addr; + uint64_t C_addr; +} kernel_arg_t; + +#endif diff --git a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp new file mode 100644 index 0000000000..0b92470a27 --- /dev/null +++ b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp @@ -0,0 +1,56 @@ +#include "common.h" +#include +#include + +namespace vt = vortex::tensor; +using ctx = vt::wmma_context; + +void kernel_body(kernel_arg_t *__UNIFORM__ arg) { + auto pA = reinterpret_cast(arg->A_addr); + auto pB = reinterpret_cast(arg->B_addr); + auto pC = reinterpret_cast(arg->C_addr); + + uint32_t M = arg->M; + uint32_t N = arg->N; + uint32_t K = arg->K; + + ctx::fragment_a fragA; + ctx::fragment_b fragB; + ctx::fragment_acc fragC; + + // calculate tile row & column based on block index + uint32_t tile_row = blockIdx.y * ctx::tileM; + uint32_t tile_col = blockIdx.x * ctx::tileN; + + // Initialize accumulator tile to zero + ctx::fill_fragment(fragC, 0); + + for 
(int i = 0; i < (K)/2; i += (ctx::tileK)/2) { + auto pTileA = pA + tile_row * K + i; + + // Load A tile + ctx::load_matrix_sync(fragA, pTileA, K); + + // Load B tile + if constexpr (vt::ITYPE::bits < 8) { + // For sub-byte matrix B must be in col-major format + auto pTileB = pB + tile_col * K + i; + ctx::load_matrix_sync(fragB, pTileB, K); + } else { + auto pTileB = pB + i * N + tile_col; + ctx::load_matrix_sync(fragB, pTileB, N); + } + + // Matrix multiply-accumulate: c += a * b + ctx::mma_sync(fragC, fragA, fragB, fragC); + } + + // Store the computed C tile + auto pTileC = pC + tile_row * N + tile_col; + ctx::store_matrix_sync(pTileC, fragC, N); +} + +int main() { + auto arg = (kernel_arg_t *)csr_read(VX_CSR_MSCRATCH); + return vx_spawn_threads(2, arg->grid_dim, arg->block_dim, (vx_kernel_func_cb)kernel_body, arg); +} diff --git a/tests/regression/sgemm_tcu_struct_sparse/main.cpp b/tests/regression/sgemm_tcu_struct_sparse/main.cpp new file mode 100644 index 0000000000..accb95a92c --- /dev/null +++ b/tests/regression/sgemm_tcu_struct_sparse/main.cpp @@ -0,0 +1,910 @@ +#include "common.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define FLOAT_ULP 6 +#define MAX_ERRORS 100 + +#define RT_CHECK(_expr) \ + do { \ + int _ret = _expr; \ + if (0 == _ret) \ + break; \ + printf("Error: '%s' returned %d!\n", #_expr, (int)_ret); \ + cleanup(); \ + exit(-1); \ + } while (false) + +using namespace vortex; +namespace vt = tensor; + +/////////////////////////////////////////////////////////////////////////////// + +static void convert_row_to_col_major_4bit(uint8_t *dst, uint32_t width, uint32_t height, const uint8_t *src) { + // Calculate output size and stride + uint32_t out_bytes = (width * height + 1) / 2; + memset(dst, 0, out_bytes); + uint32_t dst_stride = (height + 1) / 2; // Bytes per column in output + + // For each column in source (which becomes row in destination) + for (uint32_t c = 0; c < width; ++c) { + 
uint32_t base = c * dst_stride; + + // For each row in source (which becomes column in destination) + for (uint32_t r = 0; r < height; r += 2) { + // Calculate source indices (row-major) + uint32_t idx_even = r * width + c; + uint32_t idx_odd = (r + 1) * width + c; + + // Extract nibbles - consistent with data_accessor_t + uint8_t b_even = src[idx_even / 2]; + uint8_t b_odd = (r + 1 < height) ? src[idx_odd / 2] : 0; + + uint8_t nib_even = (idx_even & 1) ? (b_even >> 4) : (b_even & 0x0F); + uint8_t nib_odd = (r + 1 < height) + ? ((idx_odd & 1) ? (b_odd >> 4) : (b_odd & 0x0F)) + : 0; + + // Pack into destination: even row in low nibble, odd row in high nibble + dst[base + r / 2] = (nib_odd << 4) | nib_even; + } + } +} + +/////////////////////////////////////////////////////////////////////////////// + +template +struct data_accessor_t { + using Type = typename T::dtype; + static Type read(const Type *ptr, uint32_t offset) { + return ptr[offset]; + } + static void write(Type *ptr, uint32_t offset, Type value) { + ptr[offset] = value; + } +}; + +template <> +struct data_accessor_t { + static uint8_t read(const uint8_t *ptr, uint32_t offset) { + uint32_t row_off = offset / 2; + bool odd = offset & 0x1; + uint8_t value8 = ptr[row_off]; + return odd ? (value8 >> 4) : (value8 & 0x0f); // to nibble + } + static void write(uint8_t *ptr, uint32_t offset, int32_t value) { + uint32_t row_off = offset / 2; + bool odd = offset & 0x1; + uint8_t old_value = ptr[row_off]; + uint8_t new_value = odd ? ((old_value & 0x0f) | (value << 4)) + : ((old_value & 0xf0) | (value & 0x0f)); + ptr[offset / 2] = new_value; + } +}; + +template <> +struct data_accessor_t { + static uint8_t read(const uint8_t *ptr, uint32_t offset) { + uint32_t row_off = offset / 2; + bool odd = offset & 0x1; + uint8_t value8 = ptr[row_off]; + return odd ? 
(value8 >> 4) : (value8 & 0x0f); // to nibble + } + static void write(uint8_t *ptr, uint32_t offset, int32_t value) { + uint32_t row_off = offset / 2; + bool odd = offset & 0x1; + uint8_t old_value = ptr[row_off]; + uint8_t new_value = odd ? ((old_value & 0x0f) | (value << 4)) + : ((old_value & 0xf0) | (value & 0x0f)); + ptr[offset / 2] = new_value; + } +}; + +template <> +struct data_accessor_t { + static uint8_t read(const uint8_t *ptr, uint32_t offset) { + uint32_t row_off = offset / 2; + bool odd = offset & 0x1; + uint8_t value8 = ptr[row_off]; + return odd ? (value8 >> 4) : (value8 & 0x0f); // extract nibble + } + static void write(uint8_t *ptr, uint32_t offset, uint8_t value) { + uint32_t row_off = offset / 2; + bool odd = offset & 0x1; + uint8_t old_value = ptr[row_off]; + uint8_t new_value = odd ? ((old_value & 0x0f) | (value << 4)) + : ((old_value & 0xf0) | (value & 0x0f)); + ptr[offset / 2] = new_value; + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +template +class Comparator {}; + +template <> +class Comparator { +public: + static int8_t generate() { + return (int8_t)rand(); + } + static bool compare(int8_t a, int8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static uint8_t generate() { + return (uint8_t)rand(); + } + static bool compare(uint8_t a, uint8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static uint8_t generate() { + return (uint8_t)rand(); // store 2 nibbles in a byte + } + static bool compare(uint8_t a, uint8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] 
expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static uint8_t generate() { + return (uint8_t)rand(); // store 2 nibbles in a byte + } + static bool compare(uint8_t a, uint8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static int8_t generate() { + return (int8_t)(rand() % 256 - 128); + } + static bool compare(int8_t a, int8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static int32_t generate() { + return (int32_t)rand(); + } + static bool compare(int32_t a, int32_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static uint16_t generate() { + auto fvalue = float(rand()) / RAND_MAX; + return rv_ftoh_s(bit_cast(fvalue), 0, nullptr); + } + static bool compare(uint16_t a, uint16_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static uint16_t generate() { + auto fvalue = float(rand()) / RAND_MAX; + return rv_ftob_s(bit_cast(fvalue), 0, nullptr); + } + static bool compare(uint16_t a, uint16_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + 
static uint8_t generate() { + auto fvalue = float(rand()) / RAND_MAX; + return rv_ftoe4m3_s(bit_cast(fvalue), 0, nullptr); + } + static bool compare(uint8_t a, uint8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static uint8_t generate() { + auto fvalue = float(rand()) / RAND_MAX; + return rv_ftoe5m2_s(bit_cast(fvalue), 0, nullptr); + } + static bool compare(uint8_t a, uint8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static uint32_t generate() { + auto fvalue = float(rand()) / RAND_MAX; + return rv_ftotf32_s(bit_cast(fvalue), 0, nullptr); + } + static bool compare(uint32_t a, uint32_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +// TODO: temp arbitrarily hardcoded scale factors +constexpr uint8_t SCALE_FACTOR_E8M0_A = 129; // val = 4, bias = 127 +constexpr uint8_t SCALE_FACTOR_E8M0_B = 131; // val = 16 +constexpr uint8_t SCALE_FACTOR_E4M3_A = 0x41; // val = 2.25, bias = 7 +constexpr uint8_t SCALE_FACTOR_E4M3_B = 0x33; // val = 0.6875 + +template <> +class Comparator { +public: + static uint8_t generate() { + return generate_with_scale(SCALE_FACTOR_E8M0_A); + } + + static uint8_t generate_with_scale(uint8_t scale_factor) { + auto fvalue = float(rand()) / RAND_MAX; + return rv_ftomxfp8_s(bit_cast(fvalue), scale_factor, 0, nullptr); + } + + static bool compare(uint8_t a, uint8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; 
+ } + return true; + } +}; + +template <> +class Comparator { +public: + static uint8_t generate() { + return generate_with_scale(SCALE_FACTOR_E4M3_A); + } + + static uint8_t generate_with_scale(uint8_t scale_factor) { + auto fvalue = float(rand()) / RAND_MAX; + return rv_ftonvfp4_s(bit_cast(fvalue), scale_factor, 0, nullptr); + } + + static bool compare(uint8_t a, uint8_t b, int index, int errors) { + if (a != b) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=0x%x, actual=0x%x\n", index, b, a); + } + return false; + } + return true; + } +}; + +template <> +class Comparator { +public: + static float generate() { + return static_cast(rand()) / RAND_MAX; + } + static bool compare(float a, float b, int index, int errors) { + if constexpr (std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value) { + if (a == 0.0f && b == 0.0f) { + return true; + } + //relative error tolerance + auto diff = std::abs((a - b)/b); + if (diff < 0.01f) { + return true; + } + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=%f, actual=%f\n", index, b, a); + } + return false; + } else { + union fi_t { + float f; + int32_t i; + }; + fi_t fa, fb; + fa.f = a; + fb.f = b; + auto d = std::abs(fa.i - fb.i); + if (d > FLOAT_ULP) { + if (errors < MAX_ERRORS) { + printf("*** error: [%d] expected=%f, actual=%f\n", index, fb.f, fa.f); + } + return false; + } + return true; + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +template +struct muladd_t { + using stype = typename S::dtype; + using dtype = typename D::dtype; + static dtype eval(stype a, stype b, dtype c) { + return static_cast(a) * static_cast(b) + c; + } +}; + +template <> +struct muladd_t { + static float eval(uint16_t a, uint16_t b, float c) { + auto fa = bit_cast(rv_htof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_htof_s(b, 0, nullptr)); + return fa * fb + c; + } +}; + +template <> +struct muladd_t { + static uint16_t 
eval(uint16_t a, uint16_t b, uint16_t c) { + auto fa = bit_cast(rv_htof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_htof_s(b, 0, nullptr)); + auto fc = bit_cast(rv_htof_s(c, 0, nullptr)); + auto fd = fa * fb + fc; + return rv_ftoh_s(bit_cast(fd), 0, nullptr); + } +}; + +template <> +struct muladd_t { + static float eval(uint16_t a, uint16_t b, float c) { + auto fa = bit_cast(rv_btof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_btof_s(b, 0, nullptr)); + return fa * fb + c; + } +}; + +template <> +struct muladd_t { + static uint16_t eval(uint16_t a, uint16_t b, uint16_t c) { + auto fa = bit_cast(rv_btof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_btof_s(b, 0, nullptr)); + auto fc = bit_cast(rv_btof_s(c, 0, nullptr)); + auto fd = fa * fb + fc; + return rv_ftob_s(bit_cast(fd), 0, nullptr); + } +}; + +template <> +struct muladd_t { + static float eval(uint8_t a, uint8_t b, float c) { + auto fa = bit_cast(rv_e4m3tof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_e4m3tof_s(b, 0, nullptr)); + return fa * fb + c; + } +}; + +template <> +struct muladd_t { + static uint8_t eval(uint8_t a, uint8_t b, uint8_t c) { + auto fa = bit_cast(rv_e4m3tof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_e4m3tof_s(b, 0, nullptr)); + auto fc = bit_cast(rv_e4m3tof_s(c, 0, nullptr)); + auto fd = fa * fb + fc; + return rv_ftoe4m3_s(bit_cast(fd), 0, nullptr); + } +}; + +template <> +struct muladd_t { + static float eval(uint8_t a, uint8_t b, float c) { + auto fa = bit_cast(rv_e5m2tof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_e5m2tof_s(b, 0, nullptr)); + return fa * fb + c; + } +}; + +template <> +struct muladd_t { + static uint8_t eval(uint8_t a, uint8_t b, uint8_t c) { + auto fa = bit_cast(rv_e5m2tof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_e5m2tof_s(b, 0, nullptr)); + auto fc = bit_cast(rv_e5m2tof_s(c, 0, nullptr)); + auto fd = fa * fb + fc; + return rv_ftoe5m2_s(bit_cast(fd), 0, nullptr); + } +}; + +template <> +struct muladd_t { + static float eval(uint32_t a, uint32_t b, float c) { + auto fa = 
bit_cast(rv_tf32tof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_tf32tof_s(b, 0, nullptr)); + return fa * fb + c; + } +}; + +template <> +struct muladd_t { + static uint32_t eval(uint32_t a, uint32_t b, uint32_t c) { + auto fa = bit_cast(rv_tf32tof_s(a, 0, nullptr)); + auto fb = bit_cast(rv_tf32tof_s(b, 0, nullptr)); + auto fc = bit_cast(rv_tf32tof_s(c, 0, nullptr)); + auto fd = fa * fb + fc; + return rv_ftotf32_s(bit_cast(fd), 0, nullptr); + } +}; + +template <> +struct muladd_t { + static float eval(uint8_t a, uint8_t b, float c) { + constexpr uint8_t sf_a = SCALE_FACTOR_E8M0_A; + constexpr uint8_t sf_b = SCALE_FACTOR_E8M0_B; + auto fa = bit_cast(rv_mxfp8tof_s(a, sf_a, 0, nullptr)); + auto fb = bit_cast(rv_mxfp8tof_s(b, sf_b, 0, nullptr)); + return fa * fb + c; + } +}; + +template <> +struct muladd_t { + static uint8_t eval(uint8_t a, uint8_t b, uint8_t c) { + constexpr uint8_t sf = SCALE_FACTOR_E8M0_A; + auto fa = bit_cast(rv_mxfp8tof_s(a, sf, 0, nullptr)); + auto fb = bit_cast(rv_mxfp8tof_s(b, sf, 0, nullptr)); + auto fc = bit_cast(rv_mxfp8tof_s(c, sf, 0, nullptr)); + auto fd = fa * fb + fc; + return rv_ftomxfp8_s(bit_cast(fd), sf, 0, nullptr); + } +}; + +template <> +struct muladd_t { + static float eval(uint8_t a, uint8_t b, float c) { + constexpr uint8_t sf_a = SCALE_FACTOR_E4M3_A; + constexpr uint8_t sf_b = SCALE_FACTOR_E4M3_B; + auto fa = bit_cast(rv_nvfp4tof_s(a, sf_a, 0, nullptr)); + auto fb = bit_cast(rv_nvfp4tof_s(b, sf_b, 0, nullptr)); + return fa * fb + c; + } +}; + +template <> +struct muladd_t { + static uint8_t eval(uint8_t a, uint8_t b, uint8_t c) { + constexpr uint8_t sf = SCALE_FACTOR_E4M3_A; + auto fa = bit_cast(rv_nvfp4tof_s(a, sf, 0, nullptr)); + auto fb = bit_cast(rv_nvfp4tof_s(b, sf, 0, nullptr)); + auto fc = bit_cast(rv_nvfp4tof_s(c, sf, 0, nullptr)); + auto fd = fa * fb + fc; + return rv_ftonvfp4_s(bit_cast(fd), sf, 0, nullptr); + } +}; + +template <> +struct muladd_t { + static int32_t eval(uint8_t a, uint8_t b, int32_t c) { + int32_t 
a_val = a & 0xF; + if (a & 0x8) { + a_val |= 0xFFFFFFF0; // sign extend + } + int32_t b_val = b & 0xF; + if (b & 0x8) { + b_val |= 0xFFFFFFF0; // sign extend + } + return a_val * b_val + c; + } +}; + +template <> +struct muladd_t { + static int32_t eval(uint8_t a, uint8_t b, int32_t c) { + int32_t a_val = a & 0xF; + int32_t b_val = b & 0xF; + return a_val * b_val + c; + } +}; + +template <> +struct muladd_t { + static int32_t eval(int8_t a, int8_t b, int32_t c) { + constexpr uint8_t sf_a = SCALE_FACTOR_E8M0_A; + constexpr uint8_t sf_b = SCALE_FACTOR_E8M0_B; + int32_t scale_exp_a = (int32_t)sf_a - 133; + float scale_factor_a = std::ldexp(1.0f, scale_exp_a); + int32_t scale_exp_b = (int32_t)sf_b - 133; + float scale_factor_b = std::ldexp(1.0f, scale_exp_b); + float product = (float)a * scale_factor_a * (float)b * scale_factor_b; + return (int32_t)product + c; + } +}; + +template +inline typename T::dtype generate_A_value() { + if constexpr (std::is_same_v) { + return Comparator::generate_with_scale(SCALE_FACTOR_E8M0_A); + } else if constexpr (std::is_same_v) { + return Comparator::generate_with_scale(SCALE_FACTOR_E4M3_A); + } else { + return Comparator::generate(); + } +} + +template +inline typename T::dtype generate_B_value() { + if constexpr (std::is_same_v) { + return Comparator::generate_with_scale(SCALE_FACTOR_E8M0_B); + } else if constexpr (std::is_same_v) { + return Comparator::generate_with_scale(SCALE_FACTOR_E4M3_B); + } else { + return Comparator::generate(); + } +} + +/////////////////////////////////////////////////////////////////////////////// + +using cfg = vt::wmma_config_t; + +using itype_t = typename vt::ITYPE::dtype; +using otype_t = typename vt::OTYPE::dtype; + + +// static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t M, uint32_t N, uint32_t K) { +// uint32_t subbytes = 8 / vt::ITYPE::bits; +// uint32_t KS = subbytes ? 
(K * subbytes) : K; +// for (uint32_t m = 0; m < M; ++m) { +// for (uint32_t n = 0; n < N; ++n) { +// otype_t sum(0); +// for (uint32_t k = 0; k < KS; ++k) { +// auto a = data_accessor_t::read(A, m * KS + k); +// auto b = data_accessor_t::read(B, k * N + n); +// sum = muladd_t::eval(a, b, sum); +// } +// data_accessor_t::write(C, m * N + n, sum); +// } +// } +// } + +// CPU reference matrix multiplication for sparse A case +static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t M, uint32_t N, uint32_t K) { + uint32_t subbytes = 8 / vt::ITYPE::bits; + uint32_t KS = subbytes ? (K * subbytes) : K; + for (uint32_t m = 0; m < M; ++m) { + for (uint32_t n = 0; n < N; ++n) { + otype_t sum(0); + for (uint32_t k = 0; k < (KS/2); ++k) { + uint32_t m_module = m % 4; + uint32_t m_block = m / 4; + auto a = data_accessor_t::read(A, m_module * KS + k + m_block * (KS/2)); + auto b = data_accessor_t::read(B, k * N + n); + sum = muladd_t::eval(a, b, sum); + } + data_accessor_t::write(C, m * N + n, sum); + } + } +} + +/////////////////////////////////////////////////////////////////////////////// + +const char *kernel_file = "kernel.vxbin"; + +uint32_t xm = 32; +uint32_t xn = 32; +uint32_t xk = 32; + +vx_device_h device = nullptr; +vx_buffer_h A_buffer = nullptr; +vx_buffer_h B_buffer = nullptr; +vx_buffer_h C_buffer = nullptr; +vx_buffer_h krnl_buffer = nullptr; +vx_buffer_h args_buffer = nullptr; +kernel_arg_t kernel_arg = {}; + +std::string last_build_options; + +static void show_usage() { + std::cout << "Vortex Sgemm TCU Test." 
<< std::endl; + std::cout << "Usage: [-m: m] [-n N] [-k: K] [-h: help]" << std::endl; +} + +static void parse_args(int argc, char **argv) { + int c; + while ((c = getopt(argc, argv, "m:n:k:i:o:hs")) != -1) { + switch (c) { + case 'm': + xm = atoi(optarg); + break; + case 'n': + xn = atoi(optarg); + break; + case 'k': + xk = atoi(optarg); + break; + case 'h': + show_usage(); + exit(0); + break; + default: + show_usage(); + exit(-1); + } + } +} + +void cleanup() { + if (device) { + vx_mem_free(A_buffer); + vx_mem_free(B_buffer); + vx_mem_free(C_buffer); + vx_mem_free(krnl_buffer); + vx_mem_free(args_buffer); + vx_dev_close(device); + } +} + + + + +int main(int argc, char *argv[]) { + // parse command arguments + parse_args(argc, argv); + + std::srand(50); + + // open device connection + std::cout << "open device connection" << std::endl; + RT_CHECK(vx_dev_open(&device)); + + uint64_t isa_flags; + RT_CHECK(vx_dev_caps(device, VX_CAPS_ISA_FLAGS, &isa_flags)); + bool has_ext = (isa_flags & VX_ISA_EXT_TCU) != 0; + if (!has_ext) { + std::cout << "TCU extension not supported!" << std::endl; + cleanup(); + return -1; + } + + uint64_t NT; + RT_CHECK(vx_dev_caps(device, VX_CAPS_NUM_THREADS, &NT)); + if (NT != NUM_THREADS) { + std::cout << "Error: device warp size (" << NT << ") must match NUM_THREADS=" << NUM_THREADS << "!" << std::endl; + return -1; + } + + uint32_t M = xm; + uint32_t N = xn; + uint32_t K = xk; + + if ((M % cfg::tileM) != 0) { + std::cout << "Error: M must be a multiple of tensor tileM!" << std::endl; + return -1; + } + + if ((N % cfg::tileN) != 0) { + std::cout << "Error: M must be a multiple of tensor tileN!" << std::endl; + return -1; + } + + if ((K % cfg::tileK) != 0) { + std::cout << "Error: M must be a multiple of tensor tileK!" 
<< std::endl; + return -1; + } + + size_t sizeA = (M * K) / 2; + //size_t sizeA = M * K; + size_t sizeB = K * N; + size_t sizeC = M * N; + + std::cout << "input data type: " << vt::ITYPE::name << " (id=" << vt::ITYPE::id << ")" << std::endl; + std::cout << "output data type: " << vt::OTYPE::name << " (id=" << vt::OTYPE::id << ")" << std::endl; + std::cout << "WMMA Core Dimension: M=" << cfg::tcM << ", N=" << cfg::tcN << ", K=" << cfg::tcK << std::endl; + std::cout << "WMMA Tile Dimension: M=" << cfg::tileM << ", N=" << cfg::tileN << ", K=" << cfg::tileK << std::endl; + std::cout << "matrix A: " << M << "x" << K << std::endl; + std::cout << "matrix B: " << K << "x" << N << std::endl; + std::cout << "matrix C: " << M << "x" << N << std::endl; + + // set block size to warp size + kernel_arg.grid_dim[0] = N / cfg::tileN; + kernel_arg.grid_dim[1] = M / cfg::tileM; + kernel_arg.block_dim[0] = NT; // warp sizeb + kernel_arg.block_dim[1] = 1; + + // set matrix dimensions + kernel_arg.M = M; + kernel_arg.N = N; + kernel_arg.K = K; + + // allocate device memory + std::cout << "allocate device memory" << std::endl; + RT_CHECK(vx_mem_alloc(device, sizeA * sizeof(itype_t), VX_MEM_READ, &A_buffer)); + RT_CHECK(vx_mem_address(A_buffer, &kernel_arg.A_addr)); + RT_CHECK(vx_mem_alloc(device, sizeB * sizeof(itype_t), VX_MEM_READ, &B_buffer)); + RT_CHECK(vx_mem_address(B_buffer, &kernel_arg.B_addr)); + RT_CHECK(vx_mem_alloc(device, sizeC * sizeof(otype_t), VX_MEM_WRITE, &C_buffer)); + RT_CHECK(vx_mem_address(C_buffer, &kernel_arg.C_addr)); + + std::cout << "A_addr=0x" << std::hex << kernel_arg.A_addr << std::endl; + std::cout << "B_addr=0x" << std::hex << kernel_arg.B_addr << std::endl; + std::cout << "C_addr=0x" << std::hex << kernel_arg.C_addr << std::endl; + + // generate source data + std::vector h_A(sizeA); + std::vector h_B(sizeB); + for (uint32_t i = 0; i < sizeA; ++i) { // assume it is pruned and compressed already + //h_A[i] = Comparator::generate(); + h_A[i] = 
static_cast(i); + } + for (uint32_t i = 0; i < sizeB; ++i) { + //h_B[i] = Comparator::generate(); + h_B[i] = static_cast(1); + } + + // upload matrix A buffer + { + std::cout << "upload matrix A buffer" << std::endl; + RT_CHECK(vx_copy_to_dev(A_buffer, h_A.data(), 0, sizeA * sizeof(itype_t))); + } + + // upload matrix B buffer + { + std::cout << "upload matrix B buffer" << std::endl; + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value) { + // sub-byte matrix B must be in col-major format + // we convert the 4-bit row-major to col-major here + std::vector h_B_col(sizeB); + convert_row_to_col_major_4bit(h_B_col.data(), N, 2 * K, (uint8_t*)h_B.data()); + RT_CHECK(vx_copy_to_dev(B_buffer, h_B_col.data(), 0, sizeB)); + } else { + RT_CHECK(vx_copy_to_dev(B_buffer, h_B.data(), 0, sizeB * sizeof(itype_t))); + } + } + + // upload program + std::cout << "upload program" << std::endl; + RT_CHECK(vx_upload_kernel_file(device, kernel_file, &krnl_buffer)); + + // upload kernel argument + std::cout << "upload kernel argument" << std::endl; + RT_CHECK(vx_upload_bytes(device, &kernel_arg, sizeof(kernel_arg_t), &args_buffer)); + + auto time_start = std::chrono::high_resolution_clock::now(); + + // start device + std::cout << "start device" << std::endl; + RT_CHECK(vx_start(device, krnl_buffer, args_buffer)); + + // wait for completion + std::cout << "wait for completion" << std::endl; + RT_CHECK(vx_ready_wait(device, VX_MAX_TIMEOUT)); + + auto time_end = std::chrono::high_resolution_clock::now(); + double elapsed = std::chrono::duration_cast(time_end - time_start).count(); + printf("Elapsed time: %lg ms\n", elapsed); + + // download destination buffer + std::vector h_C(sizeC); + std::cout << "download destination buffer" << std::endl; + RT_CHECK(vx_copy_from_dev(h_C.data(), C_buffer, 0, sizeC * sizeof(otype_t))); + + // verify result + std::cout << "verify result" << std::endl; + int errors = 0; + { + std::vector h_ref(sizeC); + 
matmul_cpu(h_ref.data(), h_A.data(), h_B.data(), M, N, K); + + for (uint32_t i = 0; i < h_ref.size(); ++i) { + if (!Comparator::compare(h_C[i], h_ref[i], i, errors)) { + ++errors; + } + } + } + + // cleanup + std::cout << "cleanup" << std::endl; + cleanup(); + + if (errors != 0) { + std::cout << "Found " << std::dec << errors << " / " << sizeC << " errors!" << std::endl; + std::cout << "FAILED!" << std::endl; + return errors; + } + + std::cout << "PASSED!" << std::endl; + + return 0; +} \ No newline at end of file diff --git a/tests/regression/sgemm_tcu_struct_sparse/sparse_test.py b/tests/regression/sgemm_tcu_struct_sparse/sparse_test.py new file mode 100644 index 0000000000..fff759b077 --- /dev/null +++ b/tests/regression/sgemm_tcu_struct_sparse/sparse_test.py @@ -0,0 +1,44 @@ +import numpy as np + +def prune_2_4_blockwise_with_mask(matrix): + """ + Perform 2:4 structured sparsity pruning on each row of the input matrix. + For each consecutive block of 4 elements, keep the two largest (by absolute value) + and zero out the rest. 
+ Returns: + pruned: np.ndarray of same shape, with smaller elements zeroed out + mask: np.ndarray of bools, True where elements were kept + """ + pruned = matrix.copy() + mask = np.zeros_like(matrix, dtype=bool) + rows, cols = matrix.shape + + for i in range(rows): + for j in range(0, cols, 4): + block = pruned[i, j:j+4] + # Skip blocks that have fewer than 4 elements (at row end) + if block.shape[0] < 4: + continue + + abs_vals = np.abs(block) + sorted_idx = np.argsort(abs_vals) + top2_idx = sorted_idx[-2:] # Indices of the two largest absolute values + + block_mask = np.zeros_like(block, dtype=bool) + block_mask[top2_idx] = True + + #apply mask: zero out the smaller two elements in the block + pruned[i, j:j+4] = block * block_mask + mask[i, j:j+4] = block_mask + + return pruned, mask + +if __name__ == "__main__": + np.random.seed(42) + matrix = np.random.randn(8, 8) + + pruned_matrix, mask_matrix = prune_2_4_blockwise_with_mask(matrix) + + print("Original matrix:\n", matrix) + print("\nPruned matrix (2:4 structured sparse):\n", pruned_matrix) + print("\nMask matrix (True=kept, False=pruned):\n", mask_matrix) \ No newline at end of file diff --git a/tests/regression/sgemm_tcu_struct_sparse/tensor_generic.cpp b/tests/regression/sgemm_tcu_struct_sparse/tensor_generic.cpp new file mode 100644 index 0000000000..63cdba4875 --- /dev/null +++ b/tests/regression/sgemm_tcu_struct_sparse/tensor_generic.cpp @@ -0,0 +1,1076 @@ +#include +#include +#include +#include +#include +#include + +#define ENABLE_SPARSITY true +// Include random header only when sparsity is enabled +#ifdef ENABLE_SPARSITY +#include +#endif + +struct int4_t { + uint8_t data; +}; + +using float32_t = float; + +// ============================================================================ +// Configuration Macros +// ============================================================================ +#ifndef NUM_THREADS +#define NUM_THREADS 8 // Should be 32 for paper accuracy +#endif + +#ifndef XLENB +#define 
XLENB 4 +#endif + +#ifndef ITYPE +#define ITYPE int16_t +#endif + +#ifndef OTYPE +#define OTYPE int32_t +#endif + +#ifndef DPLEN +#define DPLEN 0 +#endif + +// ============================================================================ +// Debug Output Macros +// ============================================================================ +#ifdef NDEBUG +#define DBG_PRINT(fmt, ...) +#else +#define DBG_PRINT(fmt, ...) \ + do { \ + fprintf(stderr, fmt, __VA_ARGS__); \ + } while (0) +#endif + +#ifdef NDEBUG +class NullStream { +public: + template NullStream &operator<<(const T &) { return *this; } + NullStream &operator<<(std::ostream &(*)(std::ostream &)) { return *this; } + void flush() {} + static NullStream &instance() { + static NullStream null_stream; + return null_stream; + } +}; +#define dbg_out NullStream::instance() +#else +#define dbg_out std::cout +#endif + +template +struct DebugPrint; + +// ============================================================================ +// WMMA Configuration Template +// ============================================================================ +template +struct wmma_config_t { +private: + static constexpr uint32_t clog2(uint32_t x) { + return (x < 2) ? 
0 : (1 + clog2(x / 2)); + } + static constexpr uint32_t tile_cap = NT * NR; + static constexpr uint32_t lg_tile_cap = clog2(tile_cap); + static constexpr uint32_t tile_en = lg_tile_cap / 2; + static constexpr uint32_t tile_em = lg_tile_cap - tile_en; + + static constexpr uint32_t block_cap = NT; + static constexpr uint32_t lg_block_cap = clog2(block_cap); + static constexpr uint32_t block_en = lg_block_cap / 2; + static constexpr uint32_t block_em = lg_block_cap - block_en; + +public: + static_assert(XB >= 0 && XB <= 8, "invalid XB value!"); + + static constexpr uint32_t i_ratio = XB / sizeof(It); + static constexpr uint32_t o_ratio = XB / sizeof(Ot); + static_assert(i_ratio * sizeof(It) == XB, "XB must be multiple of sizeof(It)"); + static_assert(o_ratio * sizeof(Ot) == XB, "XB must be multiple of sizeof(Ot)"); + + static constexpr uint32_t NumThreads = NT; + static constexpr uint32_t NumRegs = NR; + + static constexpr uint32_t xtileM = 1u << tile_em; + static constexpr uint32_t xtileN = 1u << tile_en; + static constexpr uint32_t xtileK = tile_cap / ((xtileM > xtileN) ? xtileM : xtileN); + + static constexpr uint32_t tcM = 1u << block_em; + static constexpr uint32_t tcN = 1u << block_en; + static constexpr uint32_t tcK = (DP != 0) ? DP : (block_cap / ((tcM > tcN) ? 
tcM : tcN)); + + static constexpr uint32_t m_steps = xtileM / tcM; + static constexpr uint32_t n_steps = xtileN / tcN; + static constexpr uint32_t k_steps = xtileK / tcK; + + static constexpr uint32_t a_block_size = tcM * tcK; + static constexpr uint32_t a_sub_blocks = block_cap / a_block_size; + static constexpr uint32_t a_sub_steps = m_steps / a_sub_blocks; + +#ifdef ENABLE_SPARSITY + // For 2:4 sparsity, B needs to provide both potential values + static constexpr uint32_t SPARSITY_RATIO = 2; + static constexpr uint32_t b_block_size = tcK * tcN * SPARSITY_RATIO; + static constexpr uint32_t b_sub_blocks = block_cap / b_block_size; + static constexpr uint32_t b_sub_steps = n_steps / b_sub_blocks; +#else + // Dense mode: standard B block configuration + static constexpr uint32_t b_block_size = tcK * tcN; + static constexpr uint32_t b_sub_blocks = block_cap / b_block_size; + static constexpr uint32_t b_sub_steps = n_steps / b_sub_blocks; +#endif + + static constexpr uint32_t NRA = (xtileM * xtileK) / NT; + static constexpr uint32_t NRB = (xtileN * xtileK) / NT; + static constexpr uint32_t NRC = (xtileM * xtileN) / NT; + + static constexpr uint32_t tileM = xtileM; + static constexpr uint32_t tileN = xtileN; + static constexpr uint32_t tileK = xtileK * i_ratio; + + static_assert(a_sub_steps != 0, "tcK is too small for tile A"); + static_assert(b_sub_steps != 0, "tcK is too small for tile B"); + + static_assert((xtileM * xtileK <= tile_cap), "xtileM * xtileK <= tile_cap"); + static_assert((xtileN * xtileK <= tile_cap), "xtileN * xtileK <= tile_cap"); + static_assert((xtileM * xtileN <= tile_cap), "xtileM * xtileN <= tile_cap"); + + static_assert((tcM * tcK <= block_cap), "tcM * tcK <= block_cap"); + static_assert((tcN * tcK <= block_cap), "tcN * tcK <= block_cap"); + static_assert((tcM * tcN <= block_cap), "tcM * tcN <= block_cap"); + + static_assert((xtileM % tcM) == 0, "M,m divisibility"); + static_assert((xtileN % tcN) == 0, "N,n divisibility"); + 
static_assert((xtileK % tcK) == 0, "K,k divisibility"); + + using vector_t = std::conditional_t<(XB == 1), uint8_t, + std::conditional_t<(XB == 2), uint16_t, + std::conditional_t<(XB == 4), uint32_t, uint64_t>>>; + using input_t = It; + using output_t = Ot; +}; + +// ============================================================================ +// Utility Types +// ============================================================================ +template +struct raw_unsigned { + static_assert( + sizeof(T) == 1 || sizeof(T) == 2 || + sizeof(T) == 4 || sizeof(T) == 8, + "raw_unsigned_t only supports types of size 1, 2, 4 or 8 bytes" + ); + + using type = std::conditional_t< + sizeof(T) == 1, uint8_t, + std::conditional_t< + sizeof(T) == 2, uint16_t, + std::conditional_t< + sizeof(T) == 4, uint32_t, + uint64_t + > + > + >; +}; + +template +using raw_unsigned_t = typename raw_unsigned::type; + +// ============================================================================ +// Pack Row Function +// ============================================================================ +template +D pack_row(const S *base, uint32_t ldm) { + static_assert(sizeof(D) % sizeof(S) == 0, "D must be a multiple of S"); + constexpr uint32_t count = sizeof(D) / sizeof(S); + using US = raw_unsigned_t; + D packed(0); + auto src = base; + for (uint32_t i = 0; i < count; ++i) { + US bits; + bits = *reinterpret_cast(src); + D elem = static_cast(bits); + packed |= (elem << (i * (8u * sizeof(S)))); + src += ldm; + } + return packed; +} + +// ============================================================================ +// Vector Register Type +// ============================================================================ +template +struct vector_t { +private: + std::array data_; + +public: + vector_t() = default; + + vector_t(T value) { + data_.fill(value); + } + + T* data() { + return data_.data(); + } + + const T* data() const { + return data_.data(); + } + + T& operator[](size_t idx) { + assert(idx < 
N); + return data_[idx]; + } + + const T& operator[](size_t idx) const { + assert(idx < N); + return data_[idx]; + } + + friend std::ostream &operator<<(std::ostream &os, const vector_t &v) { + os << std::hex << "{"; + for (size_t i = 0; i < N; ++i) { + if (i != 0) { + os << ", "; + } + os << "0x" << +v.data_[i]; + } + os << "}" << std::dec; + return os; + } +}; + +// ============================================================================ +// 2D Array Type +// ============================================================================ +template +struct array2d_t { +private: + std::array data_; + +public: + T* data() { + return data_.data(); + } + + const T* data() const { + return data_.data(); + } + + T &operator()(int row, int col) { + assert(row >= 0 && row < R); + assert(col >= 0 && col < C); + return data_[row * C + col]; + } + + const T &operator()(int row, int col) const { + assert(row >= 0 && row < R); + assert(col >= 0 && col < C); + return data_[row * C + col]; + } + + friend std::ostream &operator<<(std::ostream &os, const array2d_t &v) { + os << "{"; + for (size_t j = 0; j < R; ++j) { + if (j != 0) { + os << ", "; + } + os << "{"; + for (size_t i = 0; i < C; ++i) { + if (i != 0) { + os << ", "; + } + os << +v(j,i); + } + os << "}"; + } + os << "}"; + return os; + } +}; + +// ============================================================================ +// WMMA Implementation (Dense or Sparse based on ENABLE_SPARSITY) +// ============================================================================ +template +class WMMA { +private: + // Configuration constants + static constexpr uint32_t tileM = Config::tileM; + static constexpr uint32_t tileN = Config::tileN; + static constexpr uint32_t tileK = Config::tileK; + + static constexpr uint32_t tcM = Config::tcM; + static constexpr uint32_t tcN = Config::tcN; + static constexpr uint32_t tcK = Config::tcK; + + static constexpr uint32_t NT = Config::NumThreads; + static constexpr uint32_t NRA = Config::NRA; 
+ static constexpr uint32_t NRB = Config::NRB; + static constexpr uint32_t NRC = Config::NRC; + + static constexpr uint32_t m_steps = Config::m_steps; + static constexpr uint32_t n_steps = Config::n_steps; + static constexpr uint32_t k_steps = Config::k_steps; + + static constexpr uint32_t a_block_size = Config::a_block_size; + static constexpr uint32_t a_sub_blocks = Config::a_sub_blocks; + static constexpr uint32_t a_sub_steps = Config::a_sub_steps; + + static constexpr uint32_t b_block_size = Config::b_block_size; + static constexpr uint32_t b_sub_blocks = Config::b_sub_blocks; + static constexpr uint32_t b_sub_steps = Config::b_sub_steps; + + static constexpr uint32_t i_ratio = Config::i_ratio; + static constexpr uint32_t o_ratio = Config::o_ratio; + +#ifdef ENABLE_SPARSITY + // Sparsity-specific constants + static constexpr uint32_t SPARSITY_N = 2; // 2 non-zero elements + static constexpr uint32_t SPARSITY_M = 4; // out of 4 elements (2:4 sparsity) + static constexpr uint32_t METADATA_LANES = Config::NumThreads / 4 / sizeof(typename Config::input_t); // Lanes 0,1 hold metadata for NT8, int8_t, 8Registers; for int16_t, NT=8, 4Registers, only lane 0 holds metadata + static constexpr uint32_t COMPRESSION_RATE = SPARSITY_M / SPARSITY_N; // 2x compression +#endif + + using Xt = typename Config::vector_t; + using It = typename Config::input_t; + using Ot = typename Config::output_t; + + using Vreg = vector_t; + + using FragA = array2d_t; + using FragB = array2d_t; + using FragC = array2d_t; + using FragD = array2d_t; + + // Matrix fragments + FragA fragA_; + FragB fragB_; + FragC fragC_; + FragD fragD_; + +#ifdef ENABLE_SPARSITY + // Sparsity-specific data structures + using FragA_meta = array2d_t; + + FragA fragA_compressed_; // Compressed matrix A (50% storage) + FragA_meta fragA_meta_; // Metadata: 1 = non-zero, 0 = pruned + static constexpr uint32_t META_ARRAY_SIZE = (tileM * tileK) / 32; //Total meta: tileM*tileK, each RISC-V register holds 32 bits + vector_t 
packed_bit_meta_; // Packed bitmap metadata. NT = 8, int8, 8REGS, MetaThreads = 2; int16, 4Regs, MetaThreads = 1 +#endif + + FragD fragRef_; + + uint32_t loop_iteration_count_; // Counter for total loop iterations + + // ======================================================================== + // Sparsity Helper Functions (only compiled when ENABLE_SPARSITY is defined) + // ======================================================================== +#ifdef ENABLE_SPARSITY + // Apply 2:4 structured pruning pattern + void apply_2_4_pruning(std::mt19937 &gen) { + std::vector masks = {1, 1, 0, 0}; // 2 ones, 2 zeros + + for (uint32_t r = 0; r < tileM; ++r) { + for (uint32_t c = 0; c < tileK / SPARSITY_M; ++c) { + // Shuffle the mask for this group of 4 elements + std::shuffle(masks.begin(), masks.end(), gen); + + // Apply mask to each element in the group + for (uint32_t c_4 = 0; c_4 < SPARSITY_M; ++c_4) { + uint32_t col = c * SPARSITY_M + c_4; + if (masks[c_4] == 0) { + fragA_(r, col) = 0; + fragA_meta_(r, col) = 0; + } else { + fragA_meta_(r, col) = 1; + } + } + } + } + } + + // Compress matrix A by removing zeros + void compress_matrix_A() { + // Initialize compressed matrix to zero + for (uint32_t r = 0; r < tileM; ++r) { + for (uint32_t c = 0; c < tileK; ++c) { + fragA_compressed_(r, c) = 0; + } + } + + // Pack non-zero elements into compressed format + uint32_t comp_cnt = 0; + for (uint32_t r = 0; r < tileM; ++r) { + for (uint32_t c = 0; c < tileK; ++c) { + if (fragA_meta_(r, c) == 1) { + uint32_t comp_r = comp_cnt / (tileK / COMPRESSION_RATE); + uint32_t comp_c = comp_cnt % (tileK / COMPRESSION_RATE); + fragA_compressed_(comp_r, comp_c) = fragA_(r, c); + comp_cnt++; + } + } + } + } + + // Pack metadata into compact bitmap format + void pack_metadata_bitmap() { + constexpr uint32_t ELEMENTS_PER_ROW = tcK * i_ratio * COMPRESSION_RATE; + constexpr uint32_t ROWS_PER_CHUNK = tcM / COMPRESSION_RATE * sizeof(It); + + constexpr uint32_t k_steps_compressed = k_steps / 
COMPRESSION_RATE; + constexpr uint32_t num_chunks = COMPRESSION_RATE / sizeof(It); + + for (uint32_t m = 0; m < m_steps; ++m) { + for (uint32_t k = 0; k < k_steps / COMPRESSION_RATE; ++k) { + for (uint32_t chunk = 0; chunk < COMPRESSION_RATE / sizeof(It); ++chunk) { + uint32_t tmp_bit = 0; + + // Pack metadata for this chunk + for (uint32_t r_i = 0; r_i < ROWS_PER_CHUNK; ++r_i) { + for (uint32_t c_i = 0; c_i < ELEMENTS_PER_ROW; ++c_i) { + uint32_t row = r_i + chunk * ROWS_PER_CHUNK + m * tcM; + uint32_t col = c_i + k * ELEMENTS_PER_ROW; + + if (fragA_meta_(row, col) == 1) { + uint32_t bit_pos = 31 - (c_i + r_i * ELEMENTS_PER_ROW); + tmp_bit |= (1ULL << bit_pos); + } + } + } + uint32_t idx; + // + if(sizeof(It) == 1 || sizeof(It) == 2){ + idx = chunk + k * num_chunks + m * k_steps_compressed * num_chunks; + }else{ + static_assert(sizeof(It) == 1 || sizeof(It) == 2, "Only int8_t and int16_t are supported for sparsity"); + } + packed_bit_meta_[idx] = tmp_bit; + } + } + } + } + + // Extract bitmap for a specific row + uint16_t extract_row_metadata_int8_t(const Vreg &va_meta, uint32_t row_idx) const { + static_assert(sizeof(It) == 1, "int8_t extractor requires sizeof(It)==1"); + uint32_t meta_reg_idx = row_idx / COMPRESSION_RATE; + bool is_upper_half = (row_idx % COMPRESSION_RATE) == 0; + return is_upper_half ? 
+ static_cast(va_meta[meta_reg_idx] >> 16) : + static_cast(va_meta[meta_reg_idx]); + } + + uint8_t extract_row_metadata_int16_t(const Vreg &va_meta, uint32_t row_idx) const { + constexpr uint32_t ELEMENTS_PER_ROW = COMPRESSION_RATE * (tcK * i_ratio); // = 8 + constexpr uint32_t ROWS_PER_CHUNK = 32 / ELEMENTS_PER_ROW; // = 4 + constexpr uint32_t ROW_MASK = 0xFF; //masks 8bits + + uint32_t meta_reg_idx = row_idx / ROWS_PER_CHUNK; // /4 + uint32_t which_part = row_idx % ROWS_PER_CHUNK; //%4 + uint32_t shift = 32 - (which_part + 1) * ELEMENTS_PER_ROW; + + uint32_t word = (va_meta[meta_reg_idx]); + return static_cast((word >> shift) & ROW_MASK); + + } + + // Gather B column elements based on A's sparsity pattern + void gather_sparse_B_column( + It *b_collected, + const Xt *b_col_0, + const Xt *b_col_1, + uint16_t a_row_meta) const { + //for ITYPE=int16_t, a_row_meta is uint8_t + //dbg_out << " [gather_sparse_B_column] a_row_meta=0x"<< std::hex << +a_row_meta << std::dec << "\n"; + + constexpr uint32_t TOTAL_ELEMENTS = tcK * i_ratio; + uint32_t collect_idx = 0; + + static_assert(sizeof(It) == 1 || sizeof(It) == 2, "Only int8_t and int16_t are supported for sparsity"); + uint32_t b_Mask = (uint32_t{1} << (8 * sizeof(It))) - 1; // 0xFF for int8, 0xFFFF for int16 + + // Gather from first half based on upper bits of metadata + for (uint32_t bit_idx = 0; bit_idx < TOTAL_ELEMENTS; ++bit_idx) { + uint32_t bit_pos = TOTAL_ELEMENTS * SPARSITY_N - bit_idx - 1; + if ((a_row_meta & (1 << bit_pos)) != 0) { + //dbg_out << " bit 1 at"<< " bit_idx=" << bit_idx << " bit_pos=" << bit_pos << "\n"; + uint32_t element_idx = bit_idx / i_ratio; + //dbg_out << " Gathering element_idx=" << element_idx << "\n"; + uint32_t byte_pos = (bit_idx % i_ratio) * 8 * sizeof(It); + //dbg_out << " byte_pos=" << byte_pos << "\n"; + b_collected[collect_idx++] = + static_cast((b_col_0[element_idx] >> byte_pos) & b_Mask); + //dbg_out << " " << +b_collected[collect_idx-1]; + } + } + + // Gather from second half 
based on lower bits of metadata + for (uint32_t bit_idx = 0; bit_idx < TOTAL_ELEMENTS; ++bit_idx) { + if (collect_idx >= TOTAL_ELEMENTS) break; + + uint32_t bit_pos = TOTAL_ELEMENTS - bit_idx - 1; + if ((a_row_meta & (1 << bit_pos)) != 0) { + //dbg_out << " bit 1 at"<< " bit_idx=" << bit_idx << " bit_pos=" << bit_pos << "\n"; + uint32_t element_idx = bit_idx / i_ratio; + //dbg_out << " Gathering element_idx=" << element_idx << "\n"; + uint32_t byte_pos = (bit_idx % i_ratio) * 8 * sizeof(It); + //dbg_out << " byte_pos=" << byte_pos << "\n"; + b_collected[collect_idx++] = + static_cast((b_col_1[element_idx] >> byte_pos) & b_Mask); + //dbg_out << " " << +b_collected[collect_idx-1]; + } + } + //dbg_out << "\n"; + } +#endif // ENABLE_SPARSITY + + // ======================================================================== + // Load/Store Operations (different implementations for dense/sparse) + // ======================================================================== + +#ifdef ENABLE_SPARSITY + // Sparse version of load_A + void load_A(vector_t &vR, uint32_t lane, uint32_t ldm, + const It *mdata, const vector_t &A_meta) { + uint32_t block_idx = lane / a_block_size; + uint32_t lane_in_block = lane % a_block_size; + uint32_t elem_row = lane_in_block / tcK; + uint32_t elem_col = lane_in_block % tcK; + + // Load compressed data into first half of registers + for (uint32_t r = 0; r < NRA / COMPRESSION_RATE; ++r) { + uint32_t block_m = (r / (k_steps / COMPRESSION_RATE)) * a_sub_blocks + block_idx; + uint32_t block_k = r % (k_steps / COMPRESSION_RATE); + uint32_t row = block_m * tcM + elem_row; + uint32_t col = block_k * tcK + elem_col; + auto base = mdata + row * ldm + col * i_ratio; + + assert(reinterpret_cast(base) % alignof(Xt) == 0 && + "Base pointer must be aligned"); + vR[r][lane] = *reinterpret_cast(base); + } + + // Load metadata into second half (only for metadata lanes) + if (lane < METADATA_LANES) { + for (uint32_t r = NRA / COMPRESSION_RATE; r < NRA; ++r) { + 
uint32_t meta_idx = (COMPRESSION_RATE * (r - NRA / COMPRESSION_RATE) + lane)/sizeof(It); + vR[r][lane] = A_meta.data()[meta_idx]; + /* dbg_out << "[load_A] lane=" << lane << " r=" << r + << " loads meta idx=" << meta_idx + << " value=0x" << std::hex << +vR[r][lane] << std::dec << "\n"; + */ + } + } else { + for (uint32_t r = NRA / COMPRESSION_RATE; r < NRA; ++r) { + vR[r][lane] = 0; + } + } + } +#else + // Dense version of load_A + void load_A(vector_t &vR, uint32_t lane, uint32_t ldm, const It *mdata) { + uint32_t block_idx = lane / a_block_size; + uint32_t lane_in_block = lane % a_block_size; + uint32_t elem_row = lane_in_block / tcK; + uint32_t elem_col = lane_in_block % tcK; + //DBG_PRINT("[load_A] lane=%u block_idx=%u lane_in_block=%u elem=[%u,%u], src=%p-%p\n", + // lane, block_idx, lane_in_block, elem_row, elem_col, mdata, mdata + tileM * tileK); + + for (uint32_t r = 0; r < NRA; ++r) { + uint32_t block_m = (r / k_steps) * a_sub_blocks + block_idx; + uint32_t block_k = r % k_steps; + uint32_t row = block_m * tcM + elem_row; + uint32_t col = block_k * tcK + elem_col; + auto base = mdata + row * ldm + col * i_ratio; + + assert(reinterpret_cast(base) % alignof(Xt) == 0 && + "Base pointer must be aligned to sizeof(Xt)"); + vR[r][lane] = *reinterpret_cast(base); + //DBG_PRINT(" r=%u → block_m=%u block_k=%u → loads A[%u,%u] → %p → %u\n", + // r, block_m, block_k, row, col, base, vR[r][lane]); + } + } +#endif + +#ifdef ENABLE_SPARSITY + // Sparse version of load_B (loads 2x data for sparse B access) + void load_B(vector_t &vR, uint32_t lane, uint32_t ldm, const It *mdata) { + uint32_t block_idx = lane / b_block_size; + uint32_t lane_in_block = lane % b_block_size; + uint32_t elem_col = lane_in_block / (tcK * COMPRESSION_RATE); + uint32_t elem_row = lane_in_block % (tcK * COMPRESSION_RATE); + + for (uint32_t r = 0; r < NRB; ++r) { + uint32_t block_k = r / b_sub_steps; + uint32_t block_n = (r % b_sub_steps) * b_sub_blocks + block_idx; + uint32_t row = block_k * tcK * 
COMPRESSION_RATE + elem_row; + uint32_t col = block_n * tcN + elem_col; + auto base = mdata + row * ldm * i_ratio + col; + + if constexpr (sizeof(Xt) == sizeof(It)) { + vR[r][lane] = *reinterpret_cast(base); + } else { + vR[r][lane] = pack_row(base, ldm); + } + } + } +#else + // Dense version of load_B + void load_B(vector_t &vR, uint32_t lane, uint32_t ldm, const It *mdata) { + uint32_t block_idx = lane / b_block_size; + uint32_t lane_in_block = lane % b_block_size; + uint32_t elem_col = lane_in_block / tcK; + uint32_t elem_row = lane_in_block % tcK; + //DBG_PRINT("[load_B] lane=%u block_idx=%u lane_in_block=%u elem=[%u,%u], src=%p-%p\n", + // lane, block_idx, lane_in_block, elem_row, elem_col, mdata, mdata + tileK * tileN); + + for (uint32_t r = 0; r < NRB; ++r) { + uint32_t block_k = r / b_sub_steps; + uint32_t block_n = (r % b_sub_steps) * b_sub_blocks + block_idx; + uint32_t row = block_k * tcK + elem_row; + uint32_t col = block_n * tcN + elem_col; + auto base = mdata + row * ldm * i_ratio + col; + + if constexpr (sizeof(Xt) == sizeof(It)) { + vR[r][lane] = *reinterpret_cast(base); + } else { + vR[r][lane] = pack_row(base, ldm); + } + //DBG_PRINT(" r=%u → block_k=%u block_n=%u → loads B[%u,%u] → %p → %u\n", + // r, block_k, block_n, row, col, base, vR[r][lane]); + } + } +#endif + + void load_C(vector_t &vR, uint32_t lane, uint32_t ldm, const Ot *mdata) { + uint32_t elem_row = lane / tcN; + uint32_t elem_col = lane % tcN; + // DBG_PRINT("[load_C] lane=%u elem=[%u,%u], src=%p-%p\n", + // lane, elem_row, elem_col, mdata, mdata + tileM * tileN); + + for (uint32_t r = 0; r < NRC; ++r) { + uint32_t block_m = r / n_steps; + uint32_t block_n = r % n_steps; + uint32_t row = block_m * tcM + elem_row; + uint32_t col = block_n * tcN + elem_col; + auto base = mdata + row * ldm + col; + + if constexpr (sizeof(Xt) == sizeof(Ot)) { + vR[r][lane] = *reinterpret_cast(base); + } else { + Xt tmp(0); + *reinterpret_cast(&tmp) = *base; + vR[r][lane] = tmp; + } + // DBG_PRINT(" r=%u 
→ block_m=%u block_n=%u → loads C[%u,%u] → %p → %u\n", + // r, block_m, block_n, row, col, base, vR[r][lane]); + } + } + + void store_D(Ot *mdata, uint32_t lane, uint32_t ldm, const vector_t &vR) { + uint32_t elem_row = lane / tcN; + uint32_t elem_col = lane % tcN; + + // DBG_PRINT("[store_D] lane=%u elem=[%u,%u], dst=%p-%p\n", + // lane, elem_row, elem_col, mdata, mdata + tileM * tileN); + + for (uint32_t r = 0; r < NRC; ++r) { + uint32_t block_m = r / n_steps; + uint32_t block_n = r % n_steps; + uint32_t row = block_m * tcM + elem_row; + uint32_t col = block_n * tcN + elem_col; + auto base = mdata + row * ldm + col; + + if constexpr (sizeof(Xt) == sizeof(Ot)) { + *reinterpret_cast(base) = vR[r][lane]; + } else { + Xt tmp(vR[r][lane]); + *base = *reinterpret_cast(&tmp); + } + // DBG_PRINT(" r=%u → block_m=%u block_n=%u → store C[%u,%u] → %p → %u\n", + // r, block_m, block_n, row, col, base , vR[r][lane]); + } + } + + // ======================================================================== + // Core Computation Operations + // ======================================================================== + + // Fused Element-wise Dot Product + Xt FEDP(const Xt *a_row, const Xt *b_col, Xt c_val) const { + Ot acc(*reinterpret_cast(&c_val)); + auto a = reinterpret_cast(a_row); + auto b = reinterpret_cast(b_col); + for (uint32_t z = 0; z < tcK * i_ratio; ++z) { + auto a_val = static_cast(a[z]); + auto b_val = static_cast(b[z]); + acc = a_val * b_val + acc; + } + Xt ret(0); + *reinterpret_cast(&ret) = acc; + return ret; + } + +#ifdef ENABLE_SPARSITY + // Sparse Matrix Multiply-Accumulate micro-operation + Vreg MMA(uint32_t m, uint32_t n, const Vreg &va, const Vreg &va_meta, + const Vreg &vb, const Vreg &vc) { + uint32_t a_off = (m % a_sub_blocks) * a_block_size; + uint32_t b_off = (n % b_sub_blocks) * b_block_size; + + Vreg vd; + It b_col_collected[tcK * i_ratio]; + + for (uint32_t i = 0; i < tcM; ++i) { + for (uint32_t j = 0; j < tcN; ++j) { + auto a_row = &va[a_off + i * 
tcK]; + auto b_col_0 = &vb[b_off + j * tcK * COMPRESSION_RATE]; + auto b_col_1 = &vb[b_off + j * tcK * COMPRESSION_RATE + tcK]; + auto c = vc[i * tcN + j]; + + // Extract metadata for this row + uint32_t a_row_meta; + if constexpr (sizeof(It) == 1){ + a_row_meta = extract_row_metadata_int8_t(va_meta, i); + }else if (sizeof(It) == 2){ + a_row_meta = extract_row_metadata_int16_t(va_meta, i); + } + // Gather sparse B elements based on A's metadata + gather_sparse_B_column(b_col_collected, b_col_0, b_col_1, a_row_meta); + + // Compute dot product + auto d = FEDP(a_row, reinterpret_cast(b_col_collected), c); + vd[i * tcN + j] = d; + } + } + + return vd; + } +#else + // Dense Matrix Multiply-Accumulate micro-operation + Vreg MMA(uint32_t m, uint32_t n, const Vreg &va, const Vreg &vb, const Vreg &vc) { + uint32_t a_off = (m % a_sub_blocks) * a_block_size; + uint32_t b_off = (n % b_sub_blocks) * b_block_size; + + Vreg vd; + for (uint32_t i = 0; i < tcM; ++i) { + for (uint32_t j = 0; j < tcN; ++j) { + auto a_row = &va[a_off + i * tcK]; + auto b_col = &vb[b_off + j * tcK]; + auto c = vc[i * tcN + j]; + auto d = FEDP(a_row, b_col, c); + vd[i * tcN + j] = d; + } + } + + return vd; + } +#endif + +#ifdef ENABLE_SPARSITY + // Sparse matrix multiply-add operation + FragD mmadd(const FragA &A, const vector_t &A_meta, + const FragB &B, const FragC &C) { + FragD D; + vector_t vA; + vector_t vB; + vector_t vC, vD; + + dbg_out << "A=" << A << "\n"; + dbg_out << "B=" << B << "\n"; + dbg_out << "C=" << C << "\n"; + + // Load fragments into vector registers + for (uint32_t lane = 0; lane < NT; ++lane) { + load_A(vA, lane, tileK, A.data(), A_meta); + } + for (uint32_t lane = 0; lane < NT; ++lane) { + load_B(vB, lane, tileN, B.data()); + } + for (uint32_t lane = 0; lane < NT; ++lane) { + load_C(vC, lane, tileN, C.data()); + } + + // Execute micro-operations + for (uint32_t k = 0; k < k_steps / COMPRESSION_RATE; ++k) { + for (uint32_t m = 0; m < m_steps; ++m) { + for (uint32_t n = 0; n < 
n_steps; ++n) { + loop_iteration_count_++; // Count loop iterations + uint32_t idxA = (m / a_sub_blocks) * (k_steps / COMPRESSION_RATE) + k; + uint32_t idxA_meta = idxA + NRA / COMPRESSION_RATE; + uint32_t idxB = (k * n_steps + n) / b_sub_blocks; + uint32_t idxC = m * n_steps + n; + + auto &va = vA[idxA]; + auto &va_meta = vA[idxA_meta]; + auto &vb = vB[idxB]; + auto &vc = (k != 0) ? vD[idxC] : vC[idxC]; + + auto vd = MMA(m, n, va, va_meta, vb, vc); + vD[idxC] = vd; + } + } + } + + // Store results back to fragment + for (uint32_t lane = 0; lane < NT; ++lane) { + store_D(D.data(), lane, tileN, vD); + } + + dbg_out << "D=" << D << "\n"; + return D; + } +#else + // Dense matrix multiply-add operation + FragD mmadd(const FragA &A, const FragB &B, const FragC &C) { + FragD D; + vector_t vA; + vector_t vB; + vector_t vC, vD; + + dbg_out << "A=" << A << "\n"; + dbg_out << "B=" << B << "\n"; + dbg_out << "C=" << C << "\n"; + + // per-lane load + for (uint32_t lane = 0; lane < NT; ++lane) { + load_A(vA, lane, tileK, A.data()); + } + for (uint32_t lane = 0; lane < NT; ++lane) { + load_B(vB, lane, tileN, B.data()); + } + for (uint32_t lane = 0; lane < NT; ++lane) { + load_C(vC, lane, tileN, C.data()); + } + + for (uint32_t i = 0; i < NRA; ++i) { + dbg_out << "vA" << i << "=" << vA[i] << "\n"; + } + for (uint32_t i = 0; i < NRB; ++i) { + dbg_out << "vB" << i << "=" << vB[i] << "\n"; + } + for (uint32_t i = 0; i < NRC; ++i) { + dbg_out << "vC" << i << "=" << vC[i] << "\n"; + } + + // micro-ops + for (uint32_t k = 0; k < k_steps; ++k) { + for (uint32_t m = 0; m < m_steps; ++m) { + for (uint32_t n = 0; n < n_steps; ++n) { + loop_iteration_count_++; // Count loop iterations + uint32_t idxA = (m / a_sub_blocks) * k_steps + k; + uint32_t idxB = (k * n_steps + n) / b_sub_blocks; + uint32_t idxC = m * n_steps + n; + + auto &va = vA[idxA]; + auto &vb = vB[idxB]; + auto &vc = (k != 0) ? 
vD[idxC] : vC[idxC]; + + auto vd = MMA(m, n, va, vb, vc); + + // dbg_out << "[mmadd] m=" << m << " n=" << n << " k=" << k + // << " → idxA=" << idxA << " idxB=" << idxB << " idxC=" << idxC + // << " va=" << va << " vb=" << vb << " vc=" << vc << " vd=" << vd << "\n"; + + vD[idxC] = vd; + } + } + } + + dbg_out.flush(); + + for (uint32_t i = 0; i < NRC; ++i) { + dbg_out << "vD" << i << "=" << vD[i] << "\n"; + } + + // per-lane store + for (uint32_t lane = 0; lane < NT; ++lane) { + store_D(D.data(), lane, tileN, vD); + } + + dbg_out << "D=" << D << "\n"; + return D; + } +#endif + +public: + // ======================================================================== + // Public Interface + // ======================================================================== + + void init() { + int x = 0; + + // Initialize matrix A with sequential values + for (uint32_t r = 0; r < tileM; ++r) { + for (uint32_t c = 0; c < tileK; ++c) { + fragA_(r, c) = x++; + } + } + +#ifdef ENABLE_SPARSITY + // Apply 2:4 structured sparsity + std::random_device rd; + std::mt19937 gen(rd()); + apply_2_4_pruning(gen); + + // Compress sparse matrix A + compress_matrix_A(); + + // Pack metadata into bitmap format + pack_metadata_bitmap(); +#endif + + // Initialize matrix B with sequential values + for (uint32_t r = 0; r < tileK; ++r) { + for (uint32_t c = 0; c < tileN; ++c) { + fragB_(r, c) = x++; + } + } + + // Initialize matrix C to zero + for (uint32_t r = 0; r < tileM; ++r) { + for (uint32_t c = 0; c < tileN; ++c) { + fragC_(r, c) = 0; + } + } + + // Compute reference result + for (uint32_t row = 0; row < tileM; ++row) { + for (uint32_t col = 0; col < tileN; ++col) { + Ot sum(0); + for (uint32_t k = 0; k < tileK; ++k) { + auto a = static_cast(fragA_(row, k)); + auto b = static_cast(fragB_(k, col)); + sum = a * b + sum; + } + fragRef_(row, col) = sum + fragC_(row, col); + } + } + } + + float verify() const { + if constexpr (std::is_integral_v) { + int32_t err(0); + for (uint32_t row = 0; row < 
tileM; ++row) { + for (uint32_t col = 0; col < tileN; ++col) { + auto curr = static_cast(fragD_(row, col)); + auto ref = static_cast(fragRef_(row, col)); + auto diff = std::abs(curr - ref); + err = std::max(err, diff); + } + } + return static_cast(err); + } else { + float err(0); + for (uint32_t row = 0; row < tileM; ++row) { + for (uint32_t col = 0; col < tileN; ++col) { + auto curr = static_cast(fragD_(row, col)); + auto ref = static_cast(fragRef_(row, col)); + auto diff = std::fabs(curr - ref); + err = std::max(err, diff); + } + } + return err; + } + } + + uint32_t get_loop_count() const { + return loop_iteration_count_; + } + + void run() { + loop_iteration_count_ = 0; // Initialize counter +#ifdef ENABLE_SPARSITY + fragD_ = mmadd(fragA_compressed_, packed_bit_meta_, fragB_, fragC_); +#else + fragD_ = mmadd(fragA_, fragB_, fragC_); +#endif + } +}; + +// ============================================================================ +// Main Test Driver +// ============================================================================ +using cfg = wmma_config_t< + NUM_THREADS, + 8, + XLENB, + OTYPE, + ITYPE, + DPLEN>; + +int main() { + WMMA wmma; + +#ifdef ENABLE_SPARSITY + std::cout << "=== Sparse Tensor Core Configuration (2:4 Structured Sparsity) ===\n"; +#else + std::cout << "=== Dense Tensor Core Configuration ===\n"; +#endif + + std::cout + << "tileM = " << cfg::tileM << "\n" + << "tileN = " << cfg::tileN << "\n" + << "tileK = " << cfg::tileK << "\n" + << "tcM = " << cfg::tcM << "\n" + << "tcN = " << cfg::tcN << "\n" + << "tcK = " << cfg::tcK << "\n" + << "m_steps = " << cfg::m_steps << "\n" + << "n_steps = " << cfg::n_steps << "\n" + << "k_steps = " << cfg::k_steps << "\n" + << "a_block_size = " << cfg::a_block_size << "\n" + << "a_sub_blocks = " << cfg::a_sub_blocks << "\n" + << "a_sub_steps = " << cfg::a_sub_steps << "\n" + << "b_block_size = " << cfg::b_block_size << "\n" + << "b_sub_blocks = " << cfg::b_sub_blocks << "\n" + << "b_sub_steps = " << 
cfg::b_sub_steps << "\n" + << "NRA = " << cfg::NRA << "\n" + << "NRB = " << cfg::NRB << "\n" + << "NRC = " << cfg::NRC << "\n" + << "\n"; + + wmma.init(); + wmma.run(); + + auto err = wmma.verify(); + bool passed = (err < 1e-4f); + + std::cout << "Total loop iterations: " << wmma.get_loop_count() << "\n" + << "Max abs error: " << err << "\n" + << (passed ? "PASSED!" : "FAILED!") << '\n'; + + return passed ? 0 : 1; +} + +// ============================================================================ +// Build Instructions +// ============================================================================ +// Dense mode (default): +// g++ -std=c++17 -O2 tensor_generic.cpp -o a.out +// +// Sparse mode (2:4 structured sparsity): +// g++ -std=c++17 -O2 -DENABLE_SPARSITY tensor_generic.cpp -o a.out +// +// Debug builds: +// g++ -std=c++17 -g tensor_generic.cpp -o a.out +// g++ -std=c++17 -g -DENABLE_SPARSITY tensor_generic.cpp -o a.out From 93752d2339fc79628551734c0b3d234e3c9488b9 Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Thu, 5 Feb 2026 22:54:15 -0800 Subject: [PATCH 02/22] Add sparse TCU support: VX_tcu_meta module and B-column mux --- hw/rtl/VX_gpu_pkg.sv | 3 +- hw/rtl/core/VX_decode.sv | 1 + hw/rtl/tcu/VX_tcu_core.sv | 55 ++++++++++++++++++++++++++++--- hw/rtl/tcu/VX_tcu_meta.sv | 68 +++++++++++++++++++++++++++++++++++++++ hw/rtl/tcu/VX_tcu_uops.sv | 1 + 5 files changed, 123 insertions(+), 5 deletions(-) create mode 100644 hw/rtl/tcu/VX_tcu_meta.sv diff --git a/hw/rtl/VX_gpu_pkg.sv b/hw/rtl/VX_gpu_pkg.sv index 3c4cd29eb6..5f17346d84 100644 --- a/hw/rtl/VX_gpu_pkg.sv +++ b/hw/rtl/VX_gpu_pkg.sv @@ -561,9 +561,10 @@ package VX_gpu_pkg; `ifdef EXT_TCU_ENABLE typedef struct packed { - logic [(INST_ARGS_BITS-16)-1:0] __padding; + logic [(INST_ARGS_BITS-20)-1:0] __padding; logic [3:0] fmt_d; logic [3:0] fmt_s; + logic [3:0] step_k; logic [3:0] step_n; logic [3:0] step_m; } tcu_args_t; diff --git a/hw/rtl/core/VX_decode.sv b/hw/rtl/core/VX_decode.sv index 
d635264ed8..268728b37f 100644 --- a/hw/rtl/core/VX_decode.sv +++ b/hw/rtl/core/VX_decode.sv @@ -556,6 +556,7 @@ module VX_decode import VX_gpu_pkg::*; #( op_args.tcu.fmt_d = rd[3:0]; op_args.tcu.step_m = '0; op_args.tcu.step_n = '0; + op_args.tcu.step_k = '0; `USED_FREG (rd); `USED_FREG (rs1); `USED_FREG (rs2); diff --git a/hw/rtl/tcu/VX_tcu_core.sv b/hw/rtl/tcu/VX_tcu_core.sv index 09db141804..1ded91f281 100644 --- a/hw/rtl/tcu/VX_tcu_core.sv +++ b/hw/rtl/tcu/VX_tcu_core.sv @@ -62,11 +62,12 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( wire [3:0] step_m = execute_if.data.op_args.tcu.step_m; wire [3:0] step_n = execute_if.data.op_args.tcu.step_n; + wire [3:0] step_k = execute_if.data.op_args.tcu.step_k; wire [3:0] fmt_s = execute_if.data.op_args.tcu.fmt_s; wire [3:0] fmt_d = execute_if.data.op_args.tcu.fmt_d; - `UNUSED_VAR ({step_m, step_n, fmt_s, fmt_d, execute_if.data}); + `UNUSED_VAR ({step_m, step_n, step_k, fmt_s, fmt_d, execute_if.data}); wire mdata_queue_full; @@ -117,16 +118,62 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( wire [TCU_TC_M-1:0][TCU_TC_N-1:0][31:0] d_val; + // Metadata block from VX_tcu_meta (for 2:4 sparsity) + localparam I_RATIO = 4; // Elements per 32-bit word + localparam META_BLOCK_WIDTH = TCU_NT * 2 * I_RATIO; + localparam META_ROW_WIDTH = TCU_TC_K * 2 * I_RATIO; + localparam ELT_W = 32 / I_RATIO; // bits per element (8 for int8) + wire [META_BLOCK_WIDTH-1:0] vld_meta_block; + + VX_tcu_meta #( + .INSTANCE_ID (INSTANCE_ID), + .META_BLOCK_WIDTH(META_BLOCK_WIDTH) + ) tcu_meta ( + .clk (clk), + .reset (reset), + .step_m (step_m), + .step_k (step_k), + .vld_meta_block(vld_meta_block) + ); + for (genvar i = 0; i < TCU_TC_M; ++i) begin : g_i for (genvar j = 0; j < TCU_TC_N; ++j) begin : g_j - wire [TCU_TC_K-1:0][31:0] a_row, b_col; + wire [TCU_TC_K-1:0][31:0] a_row, b_col, b_col_1, b_col_2; for (genvar k_idx = 0; k_idx < TCU_TC_K; ++k_idx) begin : g_slice_assign - assign a_row[k_idx] = 
32'(execute_if.data.rs1_data[a_off + i * TCU_TC_K + k_idx]); - assign b_col[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K + k_idx]); + assign a_row[k_idx] = 32'(execute_if.data.rs1_data[a_off + i * TCU_TC_K + k_idx]); + assign b_col_1[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K + k_idx]); + assign b_col_2[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx]); end wire [31:0] c_val = 32'(execute_if.data.rs3_data[i * TCU_TC_N + j]); wire [TCU_MAX_INPUTS-1:0] vld_mask = '1; // TODO: should connect to input source + wire [META_ROW_WIDTH-1:0] vld_meta_row = vld_meta_block[META_ROW_WIDTH*i +: META_ROW_WIDTH]; + + // Sparse B-column mux: compress valid elements using 2:4 metadata + // Per K position: 2 groups of I_RATIO elements → I_RATIO valid elements + for (genvar k = 0; k < TCU_TC_K; ++k) begin : g_bmux + wire [I_RATIO-1:0] grp_mask_lo = vld_meta_row[I_RATIO * k +: I_RATIO]; + wire [I_RATIO-1:0] grp_mask_hi = vld_meta_row[I_RATIO * (TCU_TC_K + k) +: I_RATIO]; + + // Group lo: first 2 valid elements from b_col_1[k] + wire [ELT_W-1:0] lo_0 = grp_mask_lo[0] ? b_col_1[k][0*ELT_W +: ELT_W] : + grp_mask_lo[1] ? b_col_1[k][1*ELT_W +: ELT_W] : + b_col_1[k][2*ELT_W +: ELT_W]; + wire [ELT_W-1:0] lo_1 = grp_mask_lo[3] ? b_col_1[k][3*ELT_W +: ELT_W] : + grp_mask_lo[2] ? b_col_1[k][2*ELT_W +: ELT_W] : + b_col_1[k][1*ELT_W +: ELT_W]; + + // Group hi: first 2 valid elements from b_col_2[k] + wire [ELT_W-1:0] hi_0 = grp_mask_hi[0] ? b_col_2[k][0*ELT_W +: ELT_W] : + grp_mask_hi[1] ? b_col_2[k][1*ELT_W +: ELT_W] : + b_col_2[k][2*ELT_W +: ELT_W]; + wire [ELT_W-1:0] hi_1 = grp_mask_hi[3] ? b_col_2[k][3*ELT_W +: ELT_W] : + grp_mask_hi[2] ? 
b_col_2[k][2*ELT_W +: ELT_W] : + b_col_2[k][1*ELT_W +: ELT_W]; + + // Pack 4 valid elements into b_col[k] + assign b_col[k] = {hi_1, hi_0, lo_1, lo_0}; + end wire [3:0] fmt_s_r, fmt_d_r; wire [TCU_TC_K-1:0][31:0] a_row_r, b_col_r; diff --git a/hw/rtl/tcu/VX_tcu_meta.sv b/hw/rtl/tcu/VX_tcu_meta.sv new file mode 100644 index 0000000000..4a41713594 --- /dev/null +++ b/hw/rtl/tcu/VX_tcu_meta.sv @@ -0,0 +1,68 @@ +// Copyright 2019-2023 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +`include "VX_define.vh" + +/* verilator lint_off UNUSEDSIGNAL */ + +module VX_tcu_meta import VX_gpu_pkg::*, VX_tcu_pkg::*; #( + parameter `STRING INSTANCE_ID = "", + parameter META_BLOCK_WIDTH = 64 // Default: TCU_NT * 2 * I_RATIO +) ( + input wire clk, + input wire reset, + + // Step indices (from VX_tcu_core) + input wire [3:0] step_m, + input wire [3:0] step_k, + + // Output (combinational) + output wire [META_BLOCK_WIDTH-1:0] vld_meta_block +); + `UNUSED_SPARAM (INSTANCE_ID) + `UNUSED_VAR (reset) + + // Local parameters + localparam HALF_K_STEPS = TCU_K_STEPS / 2; + localparam DEPTH = TCU_M_STEPS * HALF_K_STEPS; + localparam ADDRW = `CLOG2(DEPTH); + localparam M_STEP_BITS = `CLOG2(TCU_M_STEPS); // Bits needed for step_m index + localparam K_STEP_BITS = `CLOG2(HALF_K_STEPS); // Bits needed for step_k index (sparse) + + // Read address calculation using bit concatenation (no multiplication) + wire [ADDRW-1:0] read_addr = {step_m[M_STEP_BITS-1:0], step_k[K_STEP_BITS-1:0]}; + + // 
Metadata RAM with combinational read + VX_dp_ram #( + .DATAW (META_BLOCK_WIDTH), + .SIZE (DEPTH), + .WRENW (1), + .OUT_REG (0), // Combinational read: output same cycle as address + .RDW_MODE ("R"), + .INIT_ENABLE (1), + .INIT_VALUE ({(META_BLOCK_WIDTH/4){4'b1100}}) // 2:4 pattern: positions 2,3 valid in each group of 4 + ) meta_store ( + .clk (clk), + .reset (1'b0), // No reset needed for read-only + .read (1'b1), // Always enabled (combinational) + .write (1'b0), + .wren (1'b0), + .waddr ('0), + .wdata ('0), + .raddr (read_addr), + .rdata (vld_meta_block) + ); + +endmodule + +/* verilator lint_on UNUSEDSIGNAL */ diff --git a/hw/rtl/tcu/VX_tcu_uops.sv b/hw/rtl/tcu/VX_tcu_uops.sv index 3b06ea49bb..ada0ba42e4 100644 --- a/hw/rtl/tcu/VX_tcu_uops.sv +++ b/hw/rtl/tcu/VX_tcu_uops.sv @@ -92,6 +92,7 @@ module VX_tcu_uops import assign ibuf_out.op_args.tcu.fmt_d = ibuf_in.op_args.tcu.fmt_d; assign ibuf_out.op_args.tcu.step_m = 4'(m_index); assign ibuf_out.op_args.tcu.step_n = 4'(n_index); + assign ibuf_out.op_args.tcu.step_k = 4'(k_index); assign ibuf_out.wb = 1; assign ibuf_out.rd_xregs = ibuf_in.rd_xregs; assign ibuf_out.wr_xregs = ibuf_in.wr_xregs; From a580a6cd007973981adb8c65f2395fb6e64aeb1f Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Fri, 6 Feb 2026 10:11:17 -0800 Subject: [PATCH 03/22] Add sparse TCU support: B-column mux with VX_tcu_sel module --- hw/rtl/tcu/VX_tcu_core.sv | 36 ++++++++----------------- hw/rtl/tcu/VX_tcu_sel.sv | 57 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 25 deletions(-) create mode 100644 hw/rtl/tcu/VX_tcu_sel.sv diff --git a/hw/rtl/tcu/VX_tcu_core.sv b/hw/rtl/tcu/VX_tcu_core.sv index 1ded91f281..c87f0c25d2 100644 --- a/hw/rtl/tcu/VX_tcu_core.sv +++ b/hw/rtl/tcu/VX_tcu_core.sv @@ -149,31 +149,17 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( wire [TCU_MAX_INPUTS-1:0] vld_mask = '1; // TODO: should connect to input source wire [META_ROW_WIDTH-1:0] vld_meta_row = vld_meta_block[META_ROW_WIDTH*i 
+: META_ROW_WIDTH]; - // Sparse B-column mux: compress valid elements using 2:4 metadata - // Per K position: 2 groups of I_RATIO elements → I_RATIO valid elements - for (genvar k = 0; k < TCU_TC_K; ++k) begin : g_bmux - wire [I_RATIO-1:0] grp_mask_lo = vld_meta_row[I_RATIO * k +: I_RATIO]; - wire [I_RATIO-1:0] grp_mask_hi = vld_meta_row[I_RATIO * (TCU_TC_K + k) +: I_RATIO]; - - // Group lo: first 2 valid elements from b_col_1[k] - wire [ELT_W-1:0] lo_0 = grp_mask_lo[0] ? b_col_1[k][0*ELT_W +: ELT_W] : - grp_mask_lo[1] ? b_col_1[k][1*ELT_W +: ELT_W] : - b_col_1[k][2*ELT_W +: ELT_W]; - wire [ELT_W-1:0] lo_1 = grp_mask_lo[3] ? b_col_1[k][3*ELT_W +: ELT_W] : - grp_mask_lo[2] ? b_col_1[k][2*ELT_W +: ELT_W] : - b_col_1[k][1*ELT_W +: ELT_W]; - - // Group hi: first 2 valid elements from b_col_2[k] - wire [ELT_W-1:0] hi_0 = grp_mask_hi[0] ? b_col_2[k][0*ELT_W +: ELT_W] : - grp_mask_hi[1] ? b_col_2[k][1*ELT_W +: ELT_W] : - b_col_2[k][2*ELT_W +: ELT_W]; - wire [ELT_W-1:0] hi_1 = grp_mask_hi[3] ? b_col_2[k][3*ELT_W +: ELT_W] : - grp_mask_hi[2] ? b_col_2[k][2*ELT_W +: ELT_W] : - b_col_2[k][1*ELT_W +: ELT_W]; - - // Pack 4 valid elements into b_col[k] - assign b_col[k] = {hi_1, hi_0, lo_1, lo_0}; - end + VX_tcu_sel #( + .INSTANCE_ID (INSTANCE_ID), + .META_ROW_WIDTH (META_ROW_WIDTH), + .I_RATIO (I_RATIO), + .ELT_W (ELT_W) + ) tcu_sel ( + .b_col_1 (b_col_1), + .b_col_2 (b_col_2), + .vld_meta_row (vld_meta_row), + .b_col (b_col) + ); wire [3:0] fmt_s_r, fmt_d_r; wire [TCU_TC_K-1:0][31:0] a_row_r, b_col_r; diff --git a/hw/rtl/tcu/VX_tcu_sel.sv b/hw/rtl/tcu/VX_tcu_sel.sv new file mode 100644 index 0000000000..1de6f7f0d4 --- /dev/null +++ b/hw/rtl/tcu/VX_tcu_sel.sv @@ -0,0 +1,57 @@ +// Copyright 2019-2023 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +`include "VX_define.vh" + +/* verilator lint_off UNUSEDSIGNAL */ + +module VX_tcu_sel import VX_gpu_pkg::*, VX_tcu_pkg::*; #( + parameter `STRING INSTANCE_ID = "", + parameter META_ROW_WIDTH = 16, + parameter I_RATIO = 4, + parameter ELT_W = 8 +) ( + input wire [TCU_TC_K-1:0][31:0] b_col_1, + input wire [TCU_TC_K-1:0][31:0] b_col_2, + input wire [META_ROW_WIDTH-1:0] vld_meta_row, + output wire [TCU_TC_K-1:0][31:0] b_col +); + `UNUSED_SPARAM (INSTANCE_ID); + + // Sparse B-column mux: compress valid elements using 2:4 metadata + // Per K position: 2 groups of I_RATIO elements -> I_RATIO valid elements + for (genvar k = 0; k < TCU_TC_K; ++k) begin : g_bmux + wire [I_RATIO-1:0] grp_mask_lo = vld_meta_row[I_RATIO * k +: I_RATIO]; + wire [I_RATIO-1:0] grp_mask_hi = vld_meta_row[I_RATIO * (TCU_TC_K + k) +: I_RATIO]; + + // Group lo: first 2 valid elements from b_col_1[k] + wire [ELT_W-1:0] lo_0 = grp_mask_lo[0] ? b_col_1[k][0*ELT_W +: ELT_W] : + grp_mask_lo[1] ? b_col_1[k][1*ELT_W +: ELT_W] : + b_col_1[k][2*ELT_W +: ELT_W]; + wire [ELT_W-1:0] lo_1 = grp_mask_lo[3] ? b_col_1[k][3*ELT_W +: ELT_W] : + grp_mask_lo[2] ? b_col_1[k][2*ELT_W +: ELT_W] : + b_col_1[k][1*ELT_W +: ELT_W]; + + // Group hi: first 2 valid elements from b_col_2[k] + wire [ELT_W-1:0] hi_0 = grp_mask_hi[0] ? b_col_2[k][0*ELT_W +: ELT_W] : + grp_mask_hi[1] ? b_col_2[k][1*ELT_W +: ELT_W] : + b_col_2[k][2*ELT_W +: ELT_W]; + wire [ELT_W-1:0] hi_1 = grp_mask_hi[3] ? b_col_2[k][3*ELT_W +: ELT_W] : + grp_mask_hi[2] ? 
b_col_2[k][2*ELT_W +: ELT_W] : + b_col_2[k][1*ELT_W +: ELT_W]; + + // Pack 4 valid elements into b_col[k] + assign b_col[k] = {hi_1, hi_0, lo_1, lo_0}; + end + +endmodule From aaa4a53cf187eb99c9424a0899eac8a3c6962cc7 Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Fri, 6 Feb 2026 11:37:16 -0800 Subject: [PATCH 04/22] changed the cpu_ref function --- .../sgemm_tcu_struct_sparse/main.cpp | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/regression/sgemm_tcu_struct_sparse/main.cpp b/tests/regression/sgemm_tcu_struct_sparse/main.cpp index accb95a92c..3028811785 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/main.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/main.cpp @@ -661,15 +661,23 @@ using otype_t = typename vt::OTYPE::dtype; static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t M, uint32_t N, uint32_t K) { uint32_t subbytes = 8 / vt::ITYPE::bits; uint32_t KS = subbytes ? (K * subbytes) : K; + constexpr uint8_t META_MASK = 0b1100; for (uint32_t m = 0; m < M; ++m) { for (uint32_t n = 0; n < N; ++n) { otype_t sum(0); - for (uint32_t k = 0; k < (KS/2); ++k) { - uint32_t m_module = m % 4; - uint32_t m_block = m / 4; - auto a = data_accessor_t::read(A, m_module * KS + k + m_block * (KS/2)); - auto b = data_accessor_t::read(B, k * N + n); - sum = muladd_t::eval(a, b, sum); + uint32_t m_module = m % 4; + uint32_t m_block = m / 4; + uint32_t m_count = 0; + for (uint32_t k1 = 0; k1 < (KS/4); ++k1) { + for (uint32_t k2 = 0; k2 < 4; ++k2) { + uint32_t k = k1 * 4 + k2; + if (META_MASK & (1 << k2)) { + auto a = data_accessor_t::read(A, m_module * KS + m_block * (KS/2) + m_count ); + auto b = data_accessor_t::read(B, k * N + n); + sum = muladd_t::eval(a, b, sum); + m_count++; + } + } } data_accessor_t::write(C, m * N + n, sum); } From 5164075fde8a8e515c0938d9a0bcb5e843f1e171 Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Fri, 6 Feb 2026 14:09:51 -0800 Subject: [PATCH 05/22] randomize the 
operands, fix the rtl index for b_col_1, b_col_2. --- hw/rtl/tcu/VX_tcu_core.sv | 4 ++-- tests/regression/sgemm_tcu_struct_sparse/main.cpp | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/hw/rtl/tcu/VX_tcu_core.sv b/hw/rtl/tcu/VX_tcu_core.sv index c87f0c25d2..6d3d2579b1 100644 --- a/hw/rtl/tcu/VX_tcu_core.sv +++ b/hw/rtl/tcu/VX_tcu_core.sv @@ -141,8 +141,8 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( wire [TCU_TC_K-1:0][31:0] a_row, b_col, b_col_1, b_col_2; for (genvar k_idx = 0; k_idx < TCU_TC_K; ++k_idx) begin : g_slice_assign assign a_row[k_idx] = 32'(execute_if.data.rs1_data[a_off + i * TCU_TC_K + k_idx]); - assign b_col_1[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K + k_idx]); - assign b_col_2[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx]); + assign b_col_1[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2]); + assign b_col_2[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2 + 1]); end wire [31:0] c_val = 32'(execute_if.data.rs3_data[i * TCU_TC_N + j]); diff --git a/tests/regression/sgemm_tcu_struct_sparse/main.cpp b/tests/regression/sgemm_tcu_struct_sparse/main.cpp index 3028811785..02b6dd8c7b 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/main.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/main.cpp @@ -831,12 +831,12 @@ int main(int argc, char *argv[]) { std::vector h_A(sizeA); std::vector h_B(sizeB); for (uint32_t i = 0; i < sizeA; ++i) { // assume it is pruned and compressed already - //h_A[i] = Comparator::generate(); - h_A[i] = static_cast(i); + h_A[i] = Comparator::generate(); + //h_A[i] = static_cast(i); } for (uint32_t i = 0; i < sizeB; ++i) { - //h_B[i] = Comparator::generate(); - h_B[i] = static_cast(1); + h_B[i] = Comparator::generate(); + //h_B[i] = static_cast(1); } // upload matrix A buffer From de3cd7a8f0b83f947a274000f51658c3522e3f4a Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Fri, 6 Feb 2026 
23:30:09 -0800 Subject: [PATCH 06/22] fp16 fp32 printing. This code works for int8 and int32 --- hw/rtl/tcu/VX_tcu_meta.sv | 2 +- kernel/include/vx_tensor.h | 1 + .../sgemm_tcu_struct_sparse/kernel.cpp | 32 +++++++++++++++++++ .../sgemm_tcu_struct_sparse/main.cpp | 10 +++--- 4 files changed, 40 insertions(+), 5 deletions(-) diff --git a/hw/rtl/tcu/VX_tcu_meta.sv b/hw/rtl/tcu/VX_tcu_meta.sv index 4a41713594..42de663168 100644 --- a/hw/rtl/tcu/VX_tcu_meta.sv +++ b/hw/rtl/tcu/VX_tcu_meta.sv @@ -50,7 +50,7 @@ module VX_tcu_meta import VX_gpu_pkg::*, VX_tcu_pkg::*; #( .OUT_REG (0), // Combinational read: output same cycle as address .RDW_MODE ("R"), .INIT_ENABLE (1), - .INIT_VALUE ({(META_BLOCK_WIDTH/4){4'b1100}}) // 2:4 pattern: positions 2,3 valid in each group of 4 + .INIT_VALUE ({(META_BLOCK_WIDTH/4){4'b1010}}) // 2:4 pattern: positions 1,3 valid in each group of 4 ) meta_store ( .clk (clk), .reset (1'b0), // No reset needed for read-only diff --git a/kernel/include/vx_tensor.h b/kernel/include/vx_tensor.h index 9caae4d786..fd73a91849 100644 --- a/kernel/include/vx_tensor.h +++ b/kernel/include/vx_tensor.h @@ -16,6 +16,7 @@ #include #include #include +#include // for vx_printf namespace vortex { namespace tensor { diff --git a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp index 0b92470a27..46300fa062 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp @@ -5,6 +5,19 @@ namespace vt = vortex::tensor; using ctx = vt::wmma_context; +// Decode fp16 bit pattern to value*100 (2 decimal places) using integer math +static inline int32_t fp16_to_x100(uint16_t h) { + uint32_t e = (h >> 10) & 0x1F; + uint32_t m = h & 0x3FF; + if (e == 0) return 0; // zero / subnormal → 0 + // val = 2^(e-15) * (1024+m) / 1024 + // val*100 = (1024+m)*100 * 2^(e-25) + int32_t v = (int32_t)(1024 + m) * 100; + int s = (int)e - 25; + v = (s >= 0) ?
(v << s) : (v >> (-s)); + return (h & 0x8000) ? -v : v; +} + void kernel_body(kernel_arg_t *__UNIFORM__ arg) { auto pA = reinterpret_cast(arg->A_addr); auto pB = reinterpret_cast(arg->B_addr); @@ -41,6 +54,25 @@ void kernel_body(kernel_arg_t *__UNIFORM__ arg) { ctx::load_matrix_sync(fragB, pTileB, N); } + // if (vx_thread_id() == 0 && blockIdx.x == 0 && blockIdx.y == 0) { + // for (uint32_t r = 0; r < 8; ++r) { + // uint32_t packed; + // asm volatile("fmv.x.w %0, %1" : "=r"(packed) : "f"(fragA.data[r])); + // int32_t lo = fp16_to_x100(packed & 0xFFFF); + // int32_t hi = fp16_to_x100((packed >> 16) & 0xFFFF); + // vx_printf("fragA[%d] | %d.%02d, %d.%02d\n", r, + // lo / 100, lo % 100, hi / 100, hi % 100); + // } + // for (uint32_t r = 0; r < 8; ++r) { + // uint32_t packed; + // asm volatile("fmv.x.w %0, %1" : "=r"(packed) : "f"(fragB.data[r])); + // int32_t lo = fp16_to_x100(packed & 0xFFFF); + // int32_t hi = fp16_to_x100((packed >> 16) & 0xFFFF); + // vx_printf("fragB[%d] | %d.%02d, %d.%02d\n", r, + // lo / 100, lo % 100, hi / 100, hi % 100); + // } + // } + // Matrix multiply-accumulate: c += a * b ctx::mma_sync(fragC, fragA, fragB, fragC); } diff --git a/tests/regression/sgemm_tcu_struct_sparse/main.cpp b/tests/regression/sgemm_tcu_struct_sparse/main.cpp index 02b6dd8c7b..bee0d76734 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/main.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/main.cpp @@ -661,7 +661,7 @@ using otype_t = typename vt::OTYPE::dtype; static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t M, uint32_t N, uint32_t K) { uint32_t subbytes = 8 / vt::ITYPE::bits; uint32_t KS = subbytes ? 
(K * subbytes) : K; - constexpr uint8_t META_MASK = 0b1100; + constexpr uint8_t META_MASK = 0b1010; for (uint32_t m = 0; m < M; ++m) { for (uint32_t n = 0; n < N; ++n) { otype_t sum(0); @@ -831,12 +831,14 @@ int main(int argc, char *argv[]) { std::vector h_A(sizeA); std::vector h_B(sizeB); for (uint32_t i = 0; i < sizeA; ++i) { // assume it is pruned and compressed already - h_A[i] = Comparator::generate(); + //h_A[i] = Comparator::generate(); //h_A[i] = static_cast(i); + h_A[i] = rv_ftoh_s(bit_cast((float)i), 0, nullptr); } for (uint32_t i = 0; i < sizeB; ++i) { - h_B[i] = Comparator::generate(); - //h_B[i] = static_cast(1); + //h_B[i] = Comparator::generate(); + //h_B[i] = static_cast(i); + h_B[i] = rv_ftoh_s(bit_cast((float)i), 0, nullptr); } // upload matrix A buffer From 7a125dd11d25905b89a832ba4d8e023cd8db3298 Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Fri, 6 Feb 2026 23:54:42 -0800 Subject: [PATCH 07/22] fp16/fp32 done by claude --- hw/rtl/tcu/VX_tcu_core.sv | 5 ++- hw/rtl/tcu/VX_tcu_sel.sv | 74 +++++++++++++++++++++++++++------------ 2 files changed, 55 insertions(+), 24 deletions(-) diff --git a/hw/rtl/tcu/VX_tcu_core.sv b/hw/rtl/tcu/VX_tcu_core.sv index 6d3d2579b1..17d6c9697b 100644 --- a/hw/rtl/tcu/VX_tcu_core.sv +++ b/hw/rtl/tcu/VX_tcu_core.sv @@ -119,7 +119,10 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( wire [TCU_TC_M-1:0][TCU_TC_N-1:0][31:0] d_val; // Metadata block from VX_tcu_meta (for 2:4 sparsity) - localparam I_RATIO = 4; // Elements per 32-bit word +`ifndef TCU_ITYPE_BITS +`define TCU_ITYPE_BITS 8 +`endif + localparam I_RATIO = 32 / `TCU_ITYPE_BITS; // Elements per 32-bit word localparam META_BLOCK_WIDTH = TCU_NT * 2 * I_RATIO; localparam META_ROW_WIDTH = TCU_TC_K * 2 * I_RATIO; localparam ELT_W = 32 / I_RATIO; // bits per element (8 for int8) diff --git a/hw/rtl/tcu/VX_tcu_sel.sv b/hw/rtl/tcu/VX_tcu_sel.sv index 1de6f7f0d4..0e6a805906 100644 --- a/hw/rtl/tcu/VX_tcu_sel.sv +++ b/hw/rtl/tcu/VX_tcu_sel.sv @@ -28,30 +28,58 
@@ module VX_tcu_sel import VX_gpu_pkg::*, VX_tcu_pkg::*; #( ); `UNUSED_SPARAM (INSTANCE_ID); - // Sparse B-column mux: compress valid elements using 2:4 metadata - // Per K position: 2 groups of I_RATIO elements -> I_RATIO valid elements for (genvar k = 0; k < TCU_TC_K; ++k) begin : g_bmux - wire [I_RATIO-1:0] grp_mask_lo = vld_meta_row[I_RATIO * k +: I_RATIO]; - wire [I_RATIO-1:0] grp_mask_hi = vld_meta_row[I_RATIO * (TCU_TC_K + k) +: I_RATIO]; - - // Group lo: first 2 valid elements from b_col_1[k] - wire [ELT_W-1:0] lo_0 = grp_mask_lo[0] ? b_col_1[k][0*ELT_W +: ELT_W] : - grp_mask_lo[1] ? b_col_1[k][1*ELT_W +: ELT_W] : - b_col_1[k][2*ELT_W +: ELT_W]; - wire [ELT_W-1:0] lo_1 = grp_mask_lo[3] ? b_col_1[k][3*ELT_W +: ELT_W] : - grp_mask_lo[2] ? b_col_1[k][2*ELT_W +: ELT_W] : - b_col_1[k][1*ELT_W +: ELT_W]; - - // Group hi: first 2 valid elements from b_col_2[k] - wire [ELT_W-1:0] hi_0 = grp_mask_hi[0] ? b_col_2[k][0*ELT_W +: ELT_W] : - grp_mask_hi[1] ? b_col_2[k][1*ELT_W +: ELT_W] : - b_col_2[k][2*ELT_W +: ELT_W]; - wire [ELT_W-1:0] hi_1 = grp_mask_hi[3] ? b_col_2[k][3*ELT_W +: ELT_W] : - grp_mask_hi[2] ? b_col_2[k][2*ELT_W +: ELT_W] : - b_col_2[k][1*ELT_W +: ELT_W]; - - // Pack 4 valid elements into b_col[k] - assign b_col[k] = {hi_1, hi_0, lo_1, lo_0}; + + if (I_RATIO == 4) begin : g_ratio4 + // int8: b_col_1 and b_col_2 are separate 4-element groups + // Select 2 valid from each group -> 4 output elements (4x8=32 bits) + wire [I_RATIO-1:0] grp_mask_lo = vld_meta_row[I_RATIO * k +: I_RATIO]; + wire [I_RATIO-1:0] grp_mask_hi = vld_meta_row[I_RATIO * (TCU_TC_K + k) +: I_RATIO]; + + wire [ELT_W-1:0] lo_0 = grp_mask_lo[0] ? b_col_1[k][0*ELT_W +: ELT_W] : + grp_mask_lo[1] ? b_col_1[k][1*ELT_W +: ELT_W] : + b_col_1[k][2*ELT_W +: ELT_W]; + wire [ELT_W-1:0] lo_1 = grp_mask_lo[3] ? b_col_1[k][3*ELT_W +: ELT_W] : + grp_mask_lo[2] ? b_col_1[k][2*ELT_W +: ELT_W] : + b_col_1[k][1*ELT_W +: ELT_W]; + + wire [ELT_W-1:0] hi_0 = grp_mask_hi[0] ? 
b_col_2[k][0*ELT_W +: ELT_W] : + grp_mask_hi[1] ? b_col_2[k][1*ELT_W +: ELT_W] : + b_col_2[k][2*ELT_W +: ELT_W]; + wire [ELT_W-1:0] hi_1 = grp_mask_hi[3] ? b_col_2[k][3*ELT_W +: ELT_W] : + grp_mask_hi[2] ? b_col_2[k][2*ELT_W +: ELT_W] : + b_col_2[k][1*ELT_W +: ELT_W]; + + assign b_col[k] = {hi_1, hi_0, lo_1, lo_0}; + + end else if (I_RATIO == 2) begin : g_ratio2 + // fp16: b_col_1 and b_col_2 together form ONE 4-element group + // Select 2 valid from the combined group -> 2 output elements (2x16=32 bits) + wire [I_RATIO-1:0] mask_lo = vld_meta_row[I_RATIO * k +: I_RATIO]; + wire [I_RATIO-1:0] mask_hi = vld_meta_row[I_RATIO * (TCU_TC_K + k) +: I_RATIO]; + wire [3:0] grp_mask = {mask_hi, mask_lo}; + + // Pool of 4 fp16 elements across 2 registers + wire [ELT_W-1:0] elem0 = b_col_1[k][0 +: ELT_W]; + wire [ELT_W-1:0] elem1 = b_col_1[k][ELT_W +: ELT_W]; + wire [ELT_W-1:0] elem2 = b_col_2[k][0 +: ELT_W]; + wire [ELT_W-1:0] elem3 = b_col_2[k][ELT_W +: ELT_W]; + + // First valid (scan from LSB) + wire [ELT_W-1:0] sel_0 = grp_mask[0] ? elem0 : + grp_mask[1] ? elem1 : + grp_mask[2] ? elem2 : elem3; + + // Last valid (scan from MSB) + wire [ELT_W-1:0] sel_1 = grp_mask[3] ? elem3 : + grp_mask[2] ? elem2 : + grp_mask[1] ? 
elem1 : elem0; + + assign b_col[k] = {sel_1, sel_0}; + end + end endmodule + +/* verilator lint_on UNUSEDSIGNAL */ From 7630e3baee860c1bd89ea362a13134f88f78c98d Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Sat, 7 Feb 2026 21:46:44 -0800 Subject: [PATCH 08/22] all pass with claude code --- hw/rtl/tcu/VX_tcu_sel.sv | 40 +++++++++++++++++++ hw/rtl/tcu/VX_tcu_uops.sv | 1 + kernel/include/vx_tensor.h | 29 ++++++++------ .../sgemm_tcu_struct_sparse/kernel.cpp | 15 +++---- .../sgemm_tcu_struct_sparse/main.cpp | 23 +++++------ 5 files changed, 76 insertions(+), 32 deletions(-) diff --git a/hw/rtl/tcu/VX_tcu_sel.sv b/hw/rtl/tcu/VX_tcu_sel.sv index 0e6a805906..40c6df06f9 100644 --- a/hw/rtl/tcu/VX_tcu_sel.sv +++ b/hw/rtl/tcu/VX_tcu_sel.sv @@ -76,6 +76,46 @@ module VX_tcu_sel import VX_gpu_pkg::*, VX_tcu_pkg::*; #( grp_mask[1] ? elem1 : elem0; assign b_col[k] = {sel_1, sel_0}; + + end else if (I_RATIO == 8) begin : g_ratio8 + // int4: each 32-bit register has 8 elements in 2 sub-groups of 4 + wire [I_RATIO-1:0] grp_mask_lo = vld_meta_row[I_RATIO * k +: I_RATIO]; + wire [I_RATIO-1:0] grp_mask_hi = vld_meta_row[I_RATIO * (TCU_TC_K + k) +: I_RATIO]; + wire [3:0] sg0_mask = grp_mask_lo[3:0]; + wire [3:0] sg1_mask = grp_mask_lo[7:4]; + wire [3:0] sg2_mask = grp_mask_hi[3:0]; + wire [3:0] sg3_mask = grp_mask_hi[7:4]; + + // Sub-group 0: b_col_1 low half [elements 0-3] + wire [ELT_W-1:0] sg0_0 = sg0_mask[0] ? b_col_1[k][0*ELT_W +: ELT_W] : + sg0_mask[1] ? b_col_1[k][1*ELT_W +: ELT_W] : + b_col_1[k][2*ELT_W +: ELT_W]; + wire [ELT_W-1:0] sg0_1 = sg0_mask[3] ? b_col_1[k][3*ELT_W +: ELT_W] : + sg0_mask[2] ? b_col_1[k][2*ELT_W +: ELT_W] : + b_col_1[k][1*ELT_W +: ELT_W]; + // Sub-group 1: b_col_1 high half [elements 4-7] + wire [ELT_W-1:0] sg1_0 = sg1_mask[0] ? b_col_1[k][4*ELT_W +: ELT_W] : + sg1_mask[1] ? b_col_1[k][5*ELT_W +: ELT_W] : + b_col_1[k][6*ELT_W +: ELT_W]; + wire [ELT_W-1:0] sg1_1 = sg1_mask[3] ? b_col_1[k][7*ELT_W +: ELT_W] : + sg1_mask[2] ? 
b_col_1[k][6*ELT_W +: ELT_W] : + b_col_1[k][5*ELT_W +: ELT_W]; + // Sub-group 2: b_col_2 low half [elements 0-3] + wire [ELT_W-1:0] sg2_0 = sg2_mask[0] ? b_col_2[k][0*ELT_W +: ELT_W] : + sg2_mask[1] ? b_col_2[k][1*ELT_W +: ELT_W] : + b_col_2[k][2*ELT_W +: ELT_W]; + wire [ELT_W-1:0] sg2_1 = sg2_mask[3] ? b_col_2[k][3*ELT_W +: ELT_W] : + sg2_mask[2] ? b_col_2[k][2*ELT_W +: ELT_W] : + b_col_2[k][1*ELT_W +: ELT_W]; + // Sub-group 3: b_col_2 high half [elements 4-7] + wire [ELT_W-1:0] sg3_0 = sg3_mask[0] ? b_col_2[k][4*ELT_W +: ELT_W] : + sg3_mask[1] ? b_col_2[k][5*ELT_W +: ELT_W] : + b_col_2[k][6*ELT_W +: ELT_W]; + wire [ELT_W-1:0] sg3_1 = sg3_mask[3] ? b_col_2[k][7*ELT_W +: ELT_W] : + sg3_mask[2] ? b_col_2[k][6*ELT_W +: ELT_W] : + b_col_2[k][5*ELT_W +: ELT_W]; + + assign b_col[k] = {sg3_1, sg3_0, sg2_1, sg2_0, sg1_1, sg1_0, sg0_1, sg0_0}; end end diff --git a/hw/rtl/tcu/VX_tcu_uops.sv b/hw/rtl/tcu/VX_tcu_uops.sv index ada0ba42e4..e3031a1261 100644 --- a/hw/rtl/tcu/VX_tcu_uops.sv +++ b/hw/rtl/tcu/VX_tcu_uops.sv @@ -116,6 +116,7 @@ module VX_tcu_uops import done <= 0; end else begin if (~busy && start) begin + counter <= 0; busy <= 1; done <= (TCU_UOPS == 1); end else if (busy && next) begin diff --git a/kernel/include/vx_tensor.h b/kernel/include/vx_tensor.h index fd73a91849..31d3216d37 100644 --- a/kernel/include/vx_tensor.h +++ b/kernel/include/vx_tensor.h @@ -187,30 +187,35 @@ struct wmma_context { if constexpr (src_layout == col_major) { std::swap(block_row, block_col); } + constexpr uint32_t sparse_k_steps = cfg::k_steps / 2; + constexpr uint32_t sparse_regs = cfg::m_steps * sparse_k_steps; auto base = reinterpret_cast(src) + block_row * ldm + block_col; detail::unroll_for([&](auto r) { - uint32_t block_m = r / cfg::k_steps; - uint32_t block_k = r % cfg::k_steps; + uint32_t block_m = r / sparse_k_steps; + uint32_t block_k = r % sparse_k_steps; uint32_t elem_row = block_m * m_stride; uint32_t elem_col = block_k * k_stride; if constexpr (src_layout == col_major) { 
static_assert(input_is_subbyte == false, "col_major layout is not supported for sub-byte matrix_a"); std::swap(elem_row, elem_col); - auto ptr = base + elem_row * ldm + elem_col; - if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) { - dst.data[r] = *reinterpret_cast(ptr); + if constexpr (r < sparse_regs) { + auto ptr = base + elem_row * ldm + elem_col; + if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) { + dst.data[r] = *reinterpret_cast(ptr); + } else { + dst.data[r] = input_acessor_t::pack_row(ptr, ldm); + } } else { - dst.data[r] = input_acessor_t::pack_row(ptr, ldm); + uint32_t zero = 0; + dst.data[r] = *reinterpret_cast(&zero); } } else { - // raw_major layout - auto ptr = base + elem_row * ldm + elem_col; - assert(reinterpret_cast(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); - //dst.data[r] = *reinterpret_cast(ptr); - if (r < 4) { + // row_major layout + if constexpr (r < sparse_regs) { + auto ptr = base + elem_row * ldm + elem_col; + assert(reinterpret_cast(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); dst.data[r] = *reinterpret_cast(ptr); } else { - // Zero for r=4,5,6,7 uint32_t zero = 0; dst.data[r] = *reinterpret_cast(&zero); } diff --git a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp index 46300fa062..9bb09a09a6 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp @@ -38,19 +38,20 @@ void kernel_body(kernel_arg_t *__UNIFORM__ arg) { // Initialize accumulator tile to zero ctx::fill_fragment(fragC, 0); - for (int i = 0; i < (K)/2; i += (ctx::tileK)/2) { - auto pTileA = pA + tile_row * K + i; + uint32_t stride_A = K / 2; + for (int i = 0; i < (int)(K / 2); i += (int)(ctx::tileK / 2)) { + auto pTileA = pA + tile_row * stride_A + i; - // Load A tile - ctx::load_matrix_sync(fragA, pTileA, K); + // Load A tile (compressed: stride = K/2) + 
ctx::load_matrix_sync(fragA, pTileA, stride_A); - // Load B tile + // Load B tile (full: uses 2*i to index into original K dimension) if constexpr (vt::ITYPE::bits < 8) { // For sub-byte matrix B must be in col-major format - auto pTileB = pB + tile_col * K + i; + auto pTileB = pB + tile_col * K + (2 * i); ctx::load_matrix_sync(fragB, pTileB, K); } else { - auto pTileB = pB + i * N + tile_col; + auto pTileB = pB + (2 * i) * N + tile_col; ctx::load_matrix_sync(fragB, pTileB, N); } diff --git a/tests/regression/sgemm_tcu_struct_sparse/main.cpp b/tests/regression/sgemm_tcu_struct_sparse/main.cpp index bee0d76734..4422038d8e 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/main.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/main.cpp @@ -658,24 +658,25 @@ using otype_t = typename vt::OTYPE::dtype; // } // CPU reference matrix multiplication for sparse A case +// A is stored row-major compressed: M rows, each with K/2 non-zero elements +// Metadata is hardcoded 0b1010: positions 1,3 are kept in each group of 4 static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t M, uint32_t N, uint32_t K) { - uint32_t subbytes = 8 / vt::ITYPE::bits; + uint32_t subbytes = (vt::ITYPE::bits < 8) ? (8 / vt::ITYPE::bits) : 0; uint32_t KS = subbytes ? 
(K * subbytes) : K; + uint32_t stride_A = KS / 2; constexpr uint8_t META_MASK = 0b1010; for (uint32_t m = 0; m < M; ++m) { for (uint32_t n = 0; n < N; ++n) { otype_t sum(0); - uint32_t m_module = m % 4; - uint32_t m_block = m / 4; - uint32_t m_count = 0; - for (uint32_t k1 = 0; k1 < (KS/4); ++k1) { + uint32_t a_count = 0; + for (uint32_t k1 = 0; k1 < (KS / 4); ++k1) { for (uint32_t k2 = 0; k2 < 4; ++k2) { uint32_t k = k1 * 4 + k2; if (META_MASK & (1 << k2)) { - auto a = data_accessor_t::read(A, m_module * KS + m_block * (KS/2) + m_count ); + auto a = data_accessor_t::read(A, m * stride_A + a_count); auto b = data_accessor_t::read(B, k * N + n); sum = muladd_t::eval(a, b, sum); - m_count++; + a_count++; } } } @@ -831,14 +832,10 @@ int main(int argc, char *argv[]) { std::vector h_A(sizeA); std::vector h_B(sizeB); for (uint32_t i = 0; i < sizeA; ++i) { // assume it is pruned and compressed already - //h_A[i] = Comparator::generate(); - //h_A[i] = static_cast(i); - h_A[i] = rv_ftoh_s(bit_cast((float)i), 0, nullptr); + h_A[i] = generate_A_value(); } for (uint32_t i = 0; i < sizeB; ++i) { - //h_B[i] = Comparator::generate(); - //h_B[i] = static_cast(i); - h_B[i] = rv_ftoh_s(bit_cast((float)i), 0, nullptr); + h_B[i] = generate_B_value(); } // upload matrix A buffer From 815ee7c8763a89af351d9c1e94dfac192f67f9be Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Sat, 7 Feb 2026 22:34:09 -0800 Subject: [PATCH 09/22] after all config passes, test the 0101/1010 two pattern sweap pass --- hw/rtl/tcu/VX_tcu_meta.sv | 33 ++++++++++++++----- .../sgemm_tcu_struct_sparse/main.cpp | 15 ++++++--- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/hw/rtl/tcu/VX_tcu_meta.sv b/hw/rtl/tcu/VX_tcu_meta.sv index 42de663168..104c3f6cf1 100644 --- a/hw/rtl/tcu/VX_tcu_meta.sv +++ b/hw/rtl/tcu/VX_tcu_meta.sv @@ -30,7 +30,6 @@ module VX_tcu_meta import VX_gpu_pkg::*, VX_tcu_pkg::*; #( output wire [META_BLOCK_WIDTH-1:0] vld_meta_block ); `UNUSED_SPARAM (INSTANCE_ID) - `UNUSED_VAR 
(reset) // Local parameters localparam HALF_K_STEPS = TCU_K_STEPS / 2; @@ -42,23 +41,39 @@ module VX_tcu_meta import VX_gpu_pkg::*, VX_tcu_pkg::*; #( // Read address calculation using bit concatenation (no multiplication) wire [ADDRW-1:0] read_addr = {step_m[M_STEP_BITS-1:0], step_k[K_STEP_BITS-1:0]}; - // Metadata RAM with combinational read + // Post-reset initialization: write alternating patterns into SRAM + // addr LSB = step_k[0]: even → 0101 (positions 0,2), odd → 1010 (positions 1,3) + reg [ADDRW:0] init_counter; + wire init_active = ~init_counter[ADDRW]; + wire [ADDRW-1:0] init_addr = init_counter[ADDRW-1:0]; + wire [META_BLOCK_WIDTH-1:0] init_data = init_addr[0] ? + {(META_BLOCK_WIDTH/4){4'b1010}} : + {(META_BLOCK_WIDTH/4){4'b0101}}; + + always_ff @(posedge clk) begin + if (reset) begin + init_counter <= 0; + end else if (init_active) begin + init_counter <= init_counter + 1; + end + end + + // Metadata SRAM with combinational read VX_dp_ram #( .DATAW (META_BLOCK_WIDTH), .SIZE (DEPTH), .WRENW (1), .OUT_REG (0), // Combinational read: output same cycle as address .RDW_MODE ("R"), - .INIT_ENABLE (1), - .INIT_VALUE ({(META_BLOCK_WIDTH/4){4'b1010}}) // 2:4 pattern: positions 2,3 valid in each group of 4 + .INIT_ENABLE (0) ) meta_store ( .clk (clk), - .reset (1'b0), // No reset needed for read-only + .reset (1'b0), .read (1'b1), // Always enabled (combinational) - .write (1'b0), - .wren (1'b0), - .waddr ('0), - .wdata ('0), + .write (init_active), + .wren (1'b1), + .waddr (init_addr), + .wdata (init_data), .raddr (read_addr), .rdata (vld_meta_block) ); diff --git a/tests/regression/sgemm_tcu_struct_sparse/main.cpp b/tests/regression/sgemm_tcu_struct_sparse/main.cpp index 4422038d8e..bb4019d707 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/main.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/main.cpp @@ -659,20 +659,27 @@ using otype_t = typename vt::OTYPE::dtype; // CPU reference matrix multiplication for sparse A case // A is stored row-major 
compressed: M rows, each with K/2 non-zero elements -// Metadata is hardcoded 0b1010: positions 1,3 are kept in each group of 4 +// Metadata alternates per step_k within each tile: +// step_k=0 (first half of tileK): 0101 — positions 0,2 kept +// step_k=1 (second half of tileK): 1010 — positions 1,3 kept static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t M, uint32_t N, uint32_t K) { uint32_t subbytes = (vt::ITYPE::bits < 8) ? (8 / vt::ITYPE::bits) : 0; uint32_t KS = subbytes ? (K * subbytes) : K; uint32_t stride_A = KS / 2; - constexpr uint8_t META_MASK = 0b1010; + // Scale tileK to element units (for sub-byte types, cfg::tileK is in register-element units) + uint32_t tile_k_elem = subbytes ? (cfg::tileK * subbytes) : cfg::tileK; + uint32_t half_tile = tile_k_elem / 2; for (uint32_t m = 0; m < M; ++m) { for (uint32_t n = 0; n < N; ++n) { otype_t sum(0); uint32_t a_count = 0; for (uint32_t k1 = 0; k1 < (KS / 4); ++k1) { + uint32_t k_start = k1 * 4; + uint32_t pos_in_tile = k_start % tile_k_elem; + uint8_t meta_mask = (pos_in_tile < half_tile) ? 
0b0101 : 0b1010; for (uint32_t k2 = 0; k2 < 4; ++k2) { - uint32_t k = k1 * 4 + k2; - if (META_MASK & (1 << k2)) { + uint32_t k = k_start + k2; + if (meta_mask & (1 << k2)) { auto a = data_accessor_t::read(A, m * stride_A + a_count); auto b = data_accessor_t::read(B, k * N + n); sum = muladd_t::eval(a, b, sum); From eda31a179e3b15dda054add1c661aa75a7c54d98 Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Sun, 8 Feb 2026 05:41:22 -0800 Subject: [PATCH 10/22] new instruction working mma_struct_sparse_sync by claude code --- hw/rtl/VX_gpu_pkg.sv | 3 +- hw/rtl/core/VX_decode.sv | 13 + hw/rtl/core/VX_uop_sequencer.sv | 4 +- hw/rtl/tcu/VX_tcu_core.sv | 29 +- hw/rtl/tcu/VX_tcu_pkg.sv | 15 +- hw/rtl/tcu/VX_tcu_uops.sv | 40 ++- kernel/include/vx_tensor.h | 267 ++++++++++++++---- sim/common/tensor_cfg.h | 6 +- .../sgemm_tcu_struct_sparse/kernel.cpp | 14 +- 9 files changed, 302 insertions(+), 89 deletions(-) diff --git a/hw/rtl/VX_gpu_pkg.sv b/hw/rtl/VX_gpu_pkg.sv index 5f17346d84..641054f119 100644 --- a/hw/rtl/VX_gpu_pkg.sv +++ b/hw/rtl/VX_gpu_pkg.sv @@ -460,7 +460,8 @@ package VX_gpu_pkg; `ifdef EXT_TCU_ENABLE - localparam INST_TCU_WMMA = 4'h0; + localparam INST_TCU_WMMA = 4'h0; + localparam INST_TCU_WMMA_SP = 4'h1; localparam INST_TCU_BITS = 4; `endif diff --git a/hw/rtl/core/VX_decode.sv b/hw/rtl/core/VX_decode.sv index 268728b37f..f83bd0d7e2 100644 --- a/hw/rtl/core/VX_decode.sv +++ b/hw/rtl/core/VX_decode.sv @@ -562,6 +562,19 @@ module VX_decode import VX_gpu_pkg::*; #( `USED_FREG (rs2); `USED_FREG (rs3); end + 3'h1: begin // WMMA_STRUCT_SPARSE_SYNC + ex_type = EX_TCU; + op_type = INST_OP_BITS'(INST_TCU_WMMA_SP); + op_args.tcu.fmt_s = rs1[3:0]; + op_args.tcu.fmt_d = rd[3:0]; + op_args.tcu.step_m = '0; + op_args.tcu.step_n = '0; + op_args.tcu.step_k = '0; + `USED_FREG (rd); + `USED_FREG (rs1); + `USED_FREG (rs2); + `USED_FREG (rs3); + end default:; endcase end diff --git a/hw/rtl/core/VX_uop_sequencer.sv b/hw/rtl/core/VX_uop_sequencer.sv index d3e38f4da7..2c80d53646 
100644 --- a/hw/rtl/core/VX_uop_sequencer.sv +++ b/hw/rtl/core/VX_uop_sequencer.sv @@ -39,7 +39,9 @@ module VX_uop_sequencer import `ifdef EXT_TCU_ENABLE - assign is_uop_input = (input_if.data.ex_type == EX_TCU && input_if.data.op_type == INST_TCU_WMMA); + assign is_uop_input = (input_if.data.ex_type == EX_TCU + && (input_if.data.op_type == INST_TCU_WMMA + || input_if.data.op_type == INST_TCU_WMMA_SP)); VX_tcu_uops tcu_uops ( .clk (clk), diff --git a/hw/rtl/tcu/VX_tcu_core.sv b/hw/rtl/tcu/VX_tcu_core.sv index 17d6c9697b..2e26295e24 100644 --- a/hw/rtl/tcu/VX_tcu_core.sv +++ b/hw/rtl/tcu/VX_tcu_core.sv @@ -56,9 +56,12 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( localparam PIPE_LATENCY = FEDP_LATENCY + 1; localparam MDATA_QUEUE_DEPTH = 1 << $clog2(PIPE_LATENCY); - localparam LG_A_BS = $clog2(TCU_A_BLOCK_SIZE); - localparam LG_B_BS = $clog2(TCU_B_BLOCK_SIZE); - localparam OFF_W = $clog2(TCU_BLOCK_CAP); + localparam LG_A_BS = $clog2(TCU_A_BLOCK_SIZE); + localparam LG_B_BS = $clog2(TCU_B_BLOCK_SIZE); + localparam LG_B_BS_SP = $clog2(TCU_B_BLOCK_SIZE_SP); + localparam OFF_W = $clog2(TCU_BLOCK_CAP); + + wire is_sparse = (execute_if.data.op_type == INST_TCU_WMMA_SP); wire [3:0] step_m = execute_if.data.op_args.tcu.step_m; wire [3:0] step_n = execute_if.data.op_args.tcu.step_n; @@ -114,7 +117,9 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( ); wire [OFF_W-1:0] a_off = (OFF_W'(step_m) & OFF_W'(TCU_A_SUB_BLOCKS-1)) << LG_A_BS; - wire [OFF_W-1:0] b_off = (OFF_W'(step_n) & OFF_W'(TCU_B_SUB_BLOCKS-1)) << LG_B_BS; + wire [OFF_W-1:0] b_off = is_sparse + ? 
(OFF_W'(step_n) & OFF_W'(TCU_B_SUB_BLOCKS_SP-1)) << LG_B_BS_SP + : (OFF_W'(step_n) & OFF_W'(TCU_B_SUB_BLOCKS-1)) << LG_B_BS; wire [TCU_TC_M-1:0][TCU_TC_N-1:0][31:0] d_val; @@ -141,11 +146,14 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( for (genvar i = 0; i < TCU_TC_M; ++i) begin : g_i for (genvar j = 0; j < TCU_TC_N; ++j) begin : g_j - wire [TCU_TC_K-1:0][31:0] a_row, b_col, b_col_1, b_col_2; + wire [TCU_TC_K-1:0][31:0] a_row, b_col, b_col_dense, b_col_sparse, b_col_1, b_col_2; for (genvar k_idx = 0; k_idx < TCU_TC_K; ++k_idx) begin : g_slice_assign - assign a_row[k_idx] = 32'(execute_if.data.rs1_data[a_off + i * TCU_TC_K + k_idx]); - assign b_col_1[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2]); - assign b_col_2[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2 + 1]); + assign a_row[k_idx] = 32'(execute_if.data.rs1_data[a_off + i * TCU_TC_K + k_idx]); + // Dense: B registers packed with TCU_TC_K per column + assign b_col_dense[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K + k_idx]); + // Sparse: B registers packed with TCU_TC_K*2 per column (2x for sparsity) + assign b_col_1[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2]); + assign b_col_2[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2 + 1]); end wire [31:0] c_val = 32'(execute_if.data.rs3_data[i * TCU_TC_N + j]); @@ -161,9 +169,12 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( .b_col_1 (b_col_1), .b_col_2 (b_col_2), .vld_meta_row (vld_meta_row), - .b_col (b_col) + .b_col (b_col_sparse) ); + // Select dense or sparse B column + assign b_col = is_sparse ? 
b_col_sparse : b_col_dense; + wire [3:0] fmt_s_r, fmt_d_r; wire [TCU_TC_K-1:0][31:0] a_row_r, b_col_r; wire [31:0] c_val_r; diff --git a/hw/rtl/tcu/VX_tcu_pkg.sv b/hw/rtl/tcu/VX_tcu_pkg.sv index 1a172e2674..081ed4e152 100644 --- a/hw/rtl/tcu/VX_tcu_pkg.sv +++ b/hw/rtl/tcu/VX_tcu_pkg.sv @@ -73,10 +73,14 @@ package VX_tcu_pkg; localparam TCU_A_BLOCK_SIZE = TCU_TC_M * TCU_TC_K; localparam TCU_A_SUB_BLOCKS = TCU_BLOCK_CAP / TCU_A_BLOCK_SIZE; - // B micro-tiling - localparam TCU_B_BLOCK_SIZE = (TCU_TC_K * TCU_TC_N)*2; // sparsity 2601223 + // B micro-tiling (dense) + localparam TCU_B_BLOCK_SIZE = TCU_TC_K * TCU_TC_N; localparam TCU_B_SUB_BLOCKS = TCU_BLOCK_CAP / TCU_B_BLOCK_SIZE; + // B micro-tiling (sparse 2:4) + localparam TCU_B_BLOCK_SIZE_SP = (TCU_TC_K * TCU_TC_N) * 2; + localparam TCU_B_SUB_BLOCKS_SP = TCU_BLOCK_CAP / TCU_B_BLOCK_SIZE_SP; + // Register counts //localparam TCU_NRA = (TCU_TILE_M * TCU_TILE_K) / TCU_NT; localparam TCU_NRB = (TCU_TILE_N * TCU_TILE_K) / TCU_NT; @@ -177,6 +181,13 @@ package VX_tcu_pkg; trace_fmt(level, op_args.tcu.fmt_d); `TRACE(level, (".%0d.%0d", op_args.tcu.step_m, op_args.tcu.step_n)); end + INST_TCU_WMMA_SP: begin + `TRACE(level, ("WMMA_SP.")); + trace_fmt(level, op_args.tcu.fmt_s); + `TRACE(level, (".")); + trace_fmt(level, op_args.tcu.fmt_d); + `TRACE(level, (".%0d.%0d", op_args.tcu.step_m, op_args.tcu.step_n)); + end default: `TRACE(level, ("?")) endcase endtask diff --git a/hw/rtl/tcu/VX_tcu_uops.sv b/hw/rtl/tcu/VX_tcu_uops.sv index e3031a1261..132727f76e 100644 --- a/hw/rtl/tcu/VX_tcu_uops.sv +++ b/hw/rtl/tcu/VX_tcu_uops.sv @@ -33,8 +33,12 @@ module VX_tcu_uops import localparam LG_M = $clog2(TCU_M_STEPS); localparam LG_K = $clog2(TCU_K_STEPS); - localparam LG_A_SB = $clog2(TCU_A_SUB_BLOCKS); - localparam LG_B_SB = $clog2(TCU_B_SUB_BLOCKS); + localparam LG_A_SB = $clog2(TCU_A_SUB_BLOCKS); + localparam LG_B_SB = $clog2(TCU_B_SUB_BLOCKS); + localparam LG_B_SB_SP = $clog2(TCU_B_SUB_BLOCKS_SP); + + wire is_sparse_in = 
(ibuf_in.op_type == INST_TCU_WMMA_SP); + reg is_sparse; // uop counter reg [CTR_W-1:0] counter; @@ -61,13 +65,15 @@ module VX_tcu_uops import assign k_index = 0; end - // Register offsets - // wire [CTR_W-1:0] rs1_offset = ((CTR_W'(m_index) >> LG_A_SB) << LG_K) | CTR_W'(k_index); - // wire [CTR_W-1:0] rs2_offset = ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB; - // wire [CTR_W-1:0] rs3_offset = (CTR_W'(m_index) << LG_N) | CTR_W'(n_index); + // Register offsets — dense vs sparse formulas + wire [CTR_W-1:0] rs1_offset = is_sparse + ? ((CTR_W'(m_index) >> LG_A_SB) << (LG_K/2)) | CTR_W'(k_index) + : ((CTR_W'(m_index) >> LG_A_SB) << LG_K) | CTR_W'(k_index); + + wire [CTR_W-1:0] rs2_offset = is_sparse + ? ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB_SP + : ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB; - wire [CTR_W-1:0] rs1_offset = ((CTR_W'(m_index) >> LG_A_SB) << (LG_K/2)) | CTR_W'(k_index); - wire [CTR_W-1:0] rs2_offset = ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB; wire [CTR_W-1:0] rs3_offset = (CTR_W'(m_index) << LG_N) | CTR_W'(n_index); // Register calculations @@ -111,17 +117,21 @@ module VX_tcu_uops import always_ff @(posedge clk) begin if (reset) begin - counter <= 0; - busy <= 0; - done <= 0; + counter <= 0; + busy <= 0; + done <= 0; + is_sparse <= 0; end else begin if (~busy && start) begin - counter <= 0; - busy <= 1; - done <= (TCU_UOPS == 1); + counter <= 0; + busy <= 1; + is_sparse <= is_sparse_in; + done <= is_sparse_in ? (TCU_UOPS/2 == 1) : (TCU_UOPS == 1); end else if (busy && next) begin counter <= counter + ((TCU_UOPS > 1) ? 1 : 0); - done <= (counter == CTR_W'((TCU_UOPS/2)-2)); // sparsity 2601223 + done <= is_sparse + ? 
(counter == CTR_W'((TCU_UOPS/2)-2)) + : (counter == CTR_W'(TCU_UOPS-2)); busy <= ~done; end end diff --git a/kernel/include/vx_tensor.h b/kernel/include/vx_tensor.h index 31d3216d37..058ee6da8f 100644 --- a/kernel/include/vx_tensor.h +++ b/kernel/include/vx_tensor.h @@ -173,7 +173,7 @@ struct wmma_context { }); } - template + template static __attribute__((always_inline)) void load_matrix_sync(Frag &dst, const void *src, size_t ldm) { uint32_t lane = vx_thread_id(); if constexpr (Frag::Use == matrix_a) { @@ -187,18 +187,53 @@ struct wmma_context { if constexpr (src_layout == col_major) { std::swap(block_row, block_col); } - constexpr uint32_t sparse_k_steps = cfg::k_steps / 2; - constexpr uint32_t sparse_regs = cfg::m_steps * sparse_k_steps; - auto base = reinterpret_cast(src) + block_row * ldm + block_col; - detail::unroll_for([&](auto r) { - uint32_t block_m = r / sparse_k_steps; - uint32_t block_k = r % sparse_k_steps; - uint32_t elem_row = block_m * m_stride; - uint32_t elem_col = block_k * k_stride; - if constexpr (src_layout == col_major) { - static_assert(input_is_subbyte == false, "col_major layout is not supported for sub-byte matrix_a"); - std::swap(elem_row, elem_col); - if constexpr (r < sparse_regs) { + if constexpr (sparse) { + // Sparse A load: only load half the K-steps (compressed A) + constexpr uint32_t sparse_k_steps = cfg::k_steps / 2; + constexpr uint32_t sparse_regs = cfg::m_steps * sparse_k_steps; + auto base = reinterpret_cast(src) + block_row * ldm + block_col; + detail::unroll_for([&](auto r) { + uint32_t block_m = r / sparse_k_steps; + uint32_t block_k = r % sparse_k_steps; + uint32_t elem_row = block_m * m_stride; + uint32_t elem_col = block_k * k_stride; + if constexpr (src_layout == col_major) { + static_assert(input_is_subbyte == false, "col_major layout is not supported for sub-byte matrix_a"); + std::swap(elem_row, elem_col); + if constexpr (r < sparse_regs) { + auto ptr = base + elem_row * ldm + elem_col; + if constexpr 
(sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) { + dst.data[r] = *reinterpret_cast(ptr); + } else { + dst.data[r] = input_acessor_t::pack_row(ptr, ldm); + } + } else { + uint32_t zero = 0; + dst.data[r] = *reinterpret_cast(&zero); + } + } else { + // row_major layout + if constexpr (r < sparse_regs) { + auto ptr = base + elem_row * ldm + elem_col; + assert(reinterpret_cast(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); + dst.data[r] = *reinterpret_cast(ptr); + } else { + uint32_t zero = 0; + dst.data[r] = *reinterpret_cast(&zero); + } + } + }); + } else { + // Dense A load: load all K-steps + auto base = reinterpret_cast(src) + block_row * ldm + block_col; + detail::unroll_for([&](auto r) { + uint32_t block_m = r / cfg::k_steps; + uint32_t block_k = r % cfg::k_steps; + uint32_t elem_row = block_m * m_stride; + uint32_t elem_col = block_k * k_stride; + if constexpr (src_layout == col_major) { + static_assert(input_is_subbyte == false, "col_major layout is not supported for sub-byte matrix_a"); + std::swap(elem_row, elem_col); auto ptr = base + elem_row * ldm + elem_col; if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) { dst.data[r] = *reinterpret_cast(ptr); @@ -206,54 +241,81 @@ struct wmma_context { dst.data[r] = input_acessor_t::pack_row(ptr, ldm); } } else { - uint32_t zero = 0; - dst.data[r] = *reinterpret_cast(&zero); - } - } else { - // row_major layout - if constexpr (r < sparse_regs) { auto ptr = base + elem_row * ldm + elem_col; assert(reinterpret_cast(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); dst.data[r] = *reinterpret_cast(ptr); - } else { - uint32_t zero = 0; - dst.data[r] = *reinterpret_cast(&zero); } - } - }); - } else if constexpr (Frag::Use == matrix_b) { - // Load column-major matrix B - uint32_t block_idx = (cfg::b_block_size == NT) ? 0 : (lane / cfg::b_block_size); - uint32_t lane_in_blk = (cfg::b_block_size == NT) ? 
lane : (lane % cfg::b_block_size); - uint32_t block_col = (lane_in_blk / ((cfg::tcK)*2)) + (block_idx * cfg::tcN); - uint32_t block_row = (lane_in_blk % ((cfg::tcK)*2)) * i_ratio; - uint32_t n_stride = cfg::b_sub_blocks * cfg::tcN; - uint32_t k_stride = ((cfg::tcK)*2) * i_ratio; - if constexpr (src_layout == col_major) { - std::swap(block_row, block_col); + }); } - auto base = reinterpret_cast(src) + block_row * ldm + block_col; - detail::unroll_for([&](auto r) { - uint32_t block_k = r / cfg::b_sub_steps; - uint32_t block_n = r % cfg::b_sub_steps; - uint32_t elem_row = block_k * k_stride; - uint32_t elem_col = block_n * n_stride; - if constexpr (src_layout == row_major) { - static_assert(input_is_subbyte == false, "row_major layout is not supported for sub-byte matrix_b"); - auto ptr = base + elem_row * ldm + elem_col; - if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) { - dst.data[r] = *reinterpret_cast(ptr); + } else if constexpr (Frag::Use == matrix_b) { + if constexpr (sparse) { + // Sparse B load: uses 2x tcK for B block + constexpr uint32_t b_tcK = cfg::tcK * 2; + uint32_t block_idx = (cfg::b_block_size_sp == NT) ? 0 : (lane / cfg::b_block_size_sp); + uint32_t lane_in_blk = (cfg::b_block_size_sp == NT) ? 
lane : (lane % cfg::b_block_size_sp); + uint32_t block_col = (lane_in_blk / b_tcK) + (block_idx * cfg::tcN); + uint32_t block_row = (lane_in_blk % b_tcK) * i_ratio; + uint32_t n_stride = cfg::b_sub_blocks_sp * cfg::tcN; + uint32_t k_stride = b_tcK * i_ratio; + if constexpr (src_layout == col_major) { + std::swap(block_row, block_col); + } + auto base = reinterpret_cast(src) + block_row * ldm + block_col; + detail::unroll_for([&](auto r) { + uint32_t block_k = r / cfg::b_sub_steps_sp; + uint32_t block_n = r % cfg::b_sub_steps_sp; + uint32_t elem_row = block_k * k_stride; + uint32_t elem_col = block_n * n_stride; + if constexpr (src_layout == row_major) { + static_assert(input_is_subbyte == false, "row_major layout is not supported for sub-byte matrix_b"); + auto ptr = base + elem_row * ldm + elem_col; + if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) { + dst.data[r] = *reinterpret_cast(ptr); + } else { + dst.data[r] = input_acessor_t::pack_row(ptr, ldm); + } } else { - dst.data[r] = input_acessor_t::pack_row(ptr, ldm); + // col_major layout + std::swap(elem_row, elem_col); + auto ptr = base + elem_row * ldm + elem_col; + assert(reinterpret_cast(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); + dst.data[r] = *reinterpret_cast(ptr); } - } else { - // col_major layout - std::swap(elem_row, elem_col); - auto ptr = base + elem_row * ldm + elem_col; - assert(reinterpret_cast(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); - dst.data[r] = *reinterpret_cast(ptr); + }); + } else { + // Dense B load + uint32_t block_idx = (cfg::b_block_size == NT) ? 0 : (lane / cfg::b_block_size); + uint32_t lane_in_blk = (cfg::b_block_size == NT) ? 
lane : (lane % cfg::b_block_size); + uint32_t block_col = (lane_in_blk / cfg::tcK) + (block_idx * cfg::tcN); + uint32_t block_row = (lane_in_blk % cfg::tcK) * i_ratio; + uint32_t n_stride = cfg::b_sub_blocks * cfg::tcN; + uint32_t k_stride = cfg::tcK * i_ratio; + if constexpr (src_layout == col_major) { + std::swap(block_row, block_col); } - }); + auto base = reinterpret_cast(src) + block_row * ldm + block_col; + detail::unroll_for([&](auto r) { + uint32_t block_k = r / cfg::b_sub_steps; + uint32_t block_n = r % cfg::b_sub_steps; + uint32_t elem_row = block_k * k_stride; + uint32_t elem_col = block_n * n_stride; + if constexpr (src_layout == row_major) { + static_assert(input_is_subbyte == false, "row_major layout is not supported for sub-byte matrix_b"); + auto ptr = base + elem_row * ldm + elem_col; + if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) { + dst.data[r] = *reinterpret_cast(ptr); + } else { + dst.data[r] = input_acessor_t::pack_row(ptr, ldm); + } + } else { + // col_major layout + std::swap(elem_row, elem_col); + auto ptr = base + elem_row * ldm + elem_col; + assert(reinterpret_cast(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes"); + dst.data[r] = *reinterpret_cast(ptr); + } + }); + } } else { // Load accumulator matrix C uint32_t block_row = lane / cfg::tcN; @@ -413,6 +475,105 @@ struct wmma_context { } } + template + static __attribute__((always_inline)) void mma_struct_sparse_sync(FragD &fragD, const FragA &fragA, const FragB &fragB, const FragC &fragC) { + static_assert(FragA::Use == matrix_a, "A must be matrix_a"); + static_assert(FragB::Use == matrix_b, "B must be matrix_b"); + static_assert(FragC::Use == accumulator, "C must be accumulator"); + static_assert(FragD::Use == accumulator, "D must be accumulator"); + + // fragA: caller-saved registers (f0-f7) + register float fa0 __asm__("f0") = fragA.data[0]; + register float fa1 __asm__("f1") = fragA.data[1]; + register float fa2 __asm__("f2") = 
fragA.data[2]; + register float fa3 __asm__("f3") = fragA.data[3]; + register float fa4 __asm__("f4") = fragA.data[4]; + register float fa5 __asm__("f5") = fragA.data[5]; + register float fa6 __asm__("f6") = fragA.data[6]; + register float fa7 __asm__("f7") = fragA.data[7]; + + if constexpr (FragB::NR == 8) { + // fragB: caller-saved registers (f10-f17) + register float fb0 __asm__("f10") = fragB.data[0]; + register float fb1 __asm__("f11") = fragB.data[1]; + register float fb2 __asm__("f12") = fragB.data[2]; + register float fb3 __asm__("f13") = fragB.data[3]; + register float fb4 __asm__("f14") = fragB.data[4]; + register float fb5 __asm__("f15") = fragB.data[5]; + register float fb6 __asm__("f16") = fragB.data[6]; + register float fb7 __asm__("f17") = fragB.data[7]; + + // fragC: mix of caller-saved (f28-f31) and callee-saved (f18-f21) + register float fc0 __asm__("f24") = fragC.data[0]; + register float fc1 __asm__("f25") = fragC.data[1]; + register float fc2 __asm__("f26") = fragC.data[2]; + register float fc3 __asm__("f27") = fragC.data[3]; + register float fc4 __asm__("f28") = fragC.data[4]; + register float fc5 __asm__("f29") = fragC.data[5]; + register float fc6 __asm__("f30") = fragC.data[6]; + register float fc7 __asm__("f31") = fragC.data[7]; + + // Force outputs into accumulator registers + register float fd0 __asm__("f24"); + register float fd1 __asm__("f25"); + register float fd2 __asm__("f26"); + register float fd3 __asm__("f27"); + register float fd4 __asm__("f28"); + register float fd5 __asm__("f29"); + register float fd6 __asm__("f30"); + register float fd7 __asm__("f31"); + + __asm__ volatile (".insn r %[insn], 1, 2, x%[fmd], x%[fms], x0" + : "=f"(fd0), "=f"(fd1), "=f"(fd2), "=f"(fd3), "=f"(fd4), "=f"(fd5), "=f"(fd6), "=f"(fd7) + : [insn]"i"(RISCV_CUSTOM0), [fmd]"i"(Ot::id), [fms]"i"(It::id), + "f"(fa0), "f"(fa1), "f"(fa2), "f"(fa3), "f"(fa4), "f"(fa5), "f"(fa6), "f"(fa7), + "f"(fb0), "f"(fb1), "f"(fb2), "f"(fb3), "f"(fb4), "f"(fb5), "f"(fb6), 
"f"(fb7), + "f"(fc0), "f"(fc1), "f"(fc2), "f"(fc3), "f"(fc4), "f"(fc5), "f"(fc6), "f"(fc7) + ); + + // Write results to fragD + fragD.data = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; + } else { + static_assert(FragB::NR == 4, "Unsupported number of registers for FragB"); + // fragB: caller-saved registers (f28-f31) + register float fb0 __asm__("f28") = fragB.data[0]; + register float fb1 __asm__("f29") = fragB.data[1]; + register float fb2 __asm__("f30") = fragB.data[2]; + register float fb3 __asm__("f31") = fragB.data[3]; + + // fragC: mix of caller-saved (f10-f17) + register float fc0 __asm__("f10") = fragC.data[0]; + register float fc1 __asm__("f11") = fragC.data[1]; + register float fc2 __asm__("f12") = fragC.data[2]; + register float fc3 __asm__("f13") = fragC.data[3]; + register float fc4 __asm__("f14") = fragC.data[4]; + register float fc5 __asm__("f15") = fragC.data[5]; + register float fc6 __asm__("f16") = fragC.data[6]; + register float fc7 __asm__("f17") = fragC.data[7]; + + // Force outputs into accumulator registers + register float fd0 __asm__("f10"); + register float fd1 __asm__("f11"); + register float fd2 __asm__("f12"); + register float fd3 __asm__("f13"); + register float fd4 __asm__("f14"); + register float fd5 __asm__("f15"); + register float fd6 __asm__("f16"); + register float fd7 __asm__("f17"); + + __asm__ volatile (".insn r %[insn], 1, 2, x%[fmd], x%[fms], x0" + : "=f"(fd0), "=f"(fd1), "=f"(fd2), "=f"(fd3), "=f"(fd4), "=f"(fd5), "=f"(fd6), "=f"(fd7) + : [insn]"i"(RISCV_CUSTOM0), [fmd]"i"(Ot::id), [fms]"i"(It::id), + "f"(fa0), "f"(fa1), "f"(fa2), "f"(fa3), "f"(fa4), "f"(fa5), "f"(fa6), "f"(fa7), + "f"(fb0), "f"(fb1), "f"(fb2), "f"(fb3), + "f"(fc0), "f"(fc1), "f"(fc2), "f"(fc3), "f"(fc4), "f"(fc5), "f"(fc6), "f"(fc7) + ); + + // Write results to fragD + fragD.data = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; + } + } + template static __attribute__((always_inline)) void mma_sp_sync( FragD &fragD, diff --git a/sim/common/tensor_cfg.h 
b/sim/common/tensor_cfg.h index f46dbf4989..1cee6946ce 100644 --- a/sim/common/tensor_cfg.h +++ b/sim/common/tensor_cfg.h @@ -191,10 +191,14 @@ struct wmma_config_t { static constexpr uint32_t a_sub_blocks = block_cap / a_block_size; // number of A micro-tiles per register static constexpr uint32_t a_sub_steps = m_steps / a_sub_blocks; // number of A sub-steps per register - static constexpr uint32_t b_block_size = (tcK * tcN)*2; // size of B micro-tile + static constexpr uint32_t b_block_size = tcK * tcN; // size of B micro-tile (dense) static constexpr uint32_t b_sub_blocks = block_cap / b_block_size; // number of B micro-tiles per register static constexpr uint32_t b_sub_steps = n_steps / b_sub_blocks; // number of B sub-steps per register + static constexpr uint32_t b_block_size_sp = (tcK * tcN) * 2; // size of B micro-tile (sparse 2:4) + static constexpr uint32_t b_sub_blocks_sp = block_cap / b_block_size_sp; // number of B micro-tiles per register (sparse) + static constexpr uint32_t b_sub_steps_sp = n_steps / b_sub_blocks_sp; // number of B sub-steps per register (sparse) + static constexpr uint32_t NRA = (xtileM * xtileK) / NT; // Number of A registers static constexpr uint32_t NRB = (xtileN * xtileK) / NT; // Number of B registers static constexpr uint32_t NRC = (xtileM * xtileN) / NT; // Number of C registers diff --git a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp index 9bb09a09a6..48a50799f4 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp @@ -42,17 +42,17 @@ void kernel_body(kernel_arg_t *__UNIFORM__ arg) { for (int i = 0; i < (int)(K / 2); i += (int)(ctx::tileK / 2)) { auto pTileA = pA + tile_row * stride_A + i; - // Load A tile (compressed: stride = K/2) - ctx::load_matrix_sync(fragA, pTileA, stride_A); + // Load A tile (compressed: stride = K/2, sparse=true) + ctx::load_matrix_sync(fragA, pTileA, stride_A); - // 
Load B tile (full: uses 2*i to index into original K dimension) + // Load B tile (full: uses 2*i to index into original K dimension, sparse=true) if constexpr (vt::ITYPE::bits < 8) { // For sub-byte matrix B must be in col-major format auto pTileB = pB + tile_col * K + (2 * i); - ctx::load_matrix_sync(fragB, pTileB, K); + ctx::load_matrix_sync(fragB, pTileB, K); } else { auto pTileB = pB + (2 * i) * N + tile_col; - ctx::load_matrix_sync(fragB, pTileB, N); + ctx::load_matrix_sync(fragB, pTileB, N); } // if (vx_thread_id() == 0 && blockIdx.x == 0 && blockIdx.y == 0) { @@ -74,8 +74,8 @@ void kernel_body(kernel_arg_t *__UNIFORM__ arg) { // } // } - // Matrix multiply-accumulate: c += a * b - ctx::mma_sync(fragC, fragA, fragB, fragC); + // Matrix multiply-accumulate: c += a * b (sparse instruction, funct3=1) + ctx::mma_struct_sparse_sync(fragC, fragA, fragB, fragC); } // Store the computed C tile From bfcf24b1d5304672ac33f63202acea443ab66f19 Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Sun, 8 Feb 2026 06:09:35 -0800 Subject: [PATCH 11/22] prune and compress with fixed mask, fix matmul_cpu --- .../sgemm_tcu_struct_sparse/main.cpp | 109 +++++++++++------- 1 file changed, 65 insertions(+), 44 deletions(-) diff --git a/tests/regression/sgemm_tcu_struct_sparse/main.cpp b/tests/regression/sgemm_tcu_struct_sparse/main.cpp index bb4019d707..4c2757b5a8 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/main.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/main.cpp @@ -641,57 +641,73 @@ using itype_t = typename vt::ITYPE::dtype; using otype_t = typename vt::OTYPE::dtype; -// static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t M, uint32_t N, uint32_t K) { -// uint32_t subbytes = 8 / vt::ITYPE::bits; -// uint32_t KS = subbytes ?
(K * subbytes) : K; -// for (uint32_t m = 0; m < M; ++m) { -// for (uint32_t n = 0; n < N; ++n) { -// otype_t sum(0); -// for (uint32_t k = 0; k < KS; ++k) { -// auto a = data_accessor_t::read(A, m * KS + k); -// auto b = data_accessor_t::read(B, k * N + n); -// sum = muladd_t::eval(a, b, sum); -// } -// data_accessor_t::write(C, m * N + n, sum); -// } -// } -// } - -// CPU reference matrix multiplication for sparse A case -// A is stored row-major compressed: M rows, each with K/2 non-zero elements -// Metadata alternates per step_k within each tile: -// step_k=0 (first half of tileK): 0101 — positions 0,2 kept -// step_k=1 (second half of tileK): 1010 — positions 1,3 kept +// Dense CPU reference matmul. Works for sparse case too because pruned A +// has zeros at masked positions, so A[m][k]*B[k][n] = 0 naturally. static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t M, uint32_t N, uint32_t K) { uint32_t subbytes = (vt::ITYPE::bits < 8) ? (8 / vt::ITYPE::bits) : 0; uint32_t KS = subbytes ? (K * subbytes) : K; - uint32_t stride_A = KS / 2; - // Scale tileK to element units (for sub-byte types, cfg::tileK is in register-element units) - uint32_t tile_k_elem = subbytes ? (cfg::tileK * subbytes) : cfg::tileK; - uint32_t half_tile = tile_k_elem / 2; for (uint32_t m = 0; m < M; ++m) { for (uint32_t n = 0; n < N; ++n) { otype_t sum(0); - uint32_t a_count = 0; - for (uint32_t k1 = 0; k1 < (KS / 4); ++k1) { - uint32_t k_start = k1 * 4; - uint32_t pos_in_tile = k_start % tile_k_elem; - uint8_t meta_mask = (pos_in_tile < half_tile) ? 
0b0101 : 0b1010; - for (uint32_t k2 = 0; k2 < 4; ++k2) { - uint32_t k = k_start + k2; - if (meta_mask & (1 << k2)) { - auto a = data_accessor_t::read(A, m * stride_A + a_count); - auto b = data_accessor_t::read(B, k * N + n); - sum = muladd_t::eval(a, b, sum); - a_count++; - } - } + for (uint32_t k = 0; k < KS; ++k) { + auto a = data_accessor_t::read(A, m * KS + k); + auto b = data_accessor_t::read(B, k * N + n); + sum = muladd_t::eval(a, b, sum); } data_accessor_t::write(C, m * N + n, sum); } } } +// In-place: zero out elements in full A (M × K) that are NOT selected by +// the fixed 0101/1010 alternating metadata mask. +static void prune_fixed_mask(itype_t *A, uint32_t M, uint32_t K) { + uint32_t subbytes = (vt::ITYPE::bits < 8) ? (8 / vt::ITYPE::bits) : 0; + uint32_t KS = subbytes ? (K * subbytes) : K; + uint32_t tile_k_elem = subbytes ? (cfg::tileK * subbytes) : cfg::tileK; + uint32_t half_tile = tile_k_elem / 2; + + for (uint32_t m = 0; m < M; ++m) { + for (uint32_t k1 = 0; k1 < (KS / 4); ++k1) { + uint32_t k_start = k1 * 4; + uint32_t pos_in_tile = k_start % tile_k_elem; + uint8_t meta_mask = (pos_in_tile < half_tile) ? 0b0101 : 0b1010; + for (uint32_t k2 = 0; k2 < 4; ++k2) { + if (!(meta_mask & (1 << k2))) { + data_accessor_t::write(A, m * KS + k_start + k2, 0); + } + } + } + } +} + +// Extract mask-selected (non-zero) positions from pruned A (M × K) into +// compressed output (M × K/2). +static void compress_fixed_mask(itype_t *compressed, const itype_t *pruned_A, + uint32_t M, uint32_t K) { + uint32_t subbytes = (vt::ITYPE::bits < 8) ? (8 / vt::ITYPE::bits) : 0; + uint32_t KS = subbytes ? (K * subbytes) : K; + uint32_t stride_comp = KS / 2; + uint32_t tile_k_elem = subbytes ? 
(cfg::tileK * subbytes) : cfg::tileK; + uint32_t half_tile = tile_k_elem / 2; + + for (uint32_t m = 0; m < M; ++m) { + uint32_t a_out = 0; + for (uint32_t k1 = 0; k1 < (KS / 4); ++k1) { + uint32_t k_start = k1 * 4; + uint32_t pos_in_tile = k_start % tile_k_elem; + uint8_t meta_mask = (pos_in_tile < half_tile) ? 0b0101 : 0b1010; + for (uint32_t k2 = 0; k2 < 4; ++k2) { + if (meta_mask & (1 << k2)) { + auto val = data_accessor_t::read(pruned_A, m * KS + k_start + k2); + data_accessor_t::write(compressed, m * stride_comp + a_out, val); + a_out++; + } + } + } + } +} + /////////////////////////////////////////////////////////////////////////////// const char *kernel_file = "kernel.vxbin"; @@ -798,8 +814,8 @@ int main(int argc, char *argv[]) { return -1; } + size_t sizeA_full = M * K; size_t sizeA = (M * K) / 2; - //size_t sizeA = M * K; size_t sizeB = K * N; size_t sizeC = M * N; @@ -836,11 +852,16 @@ int main(int argc, char *argv[]) { std::cout << "C_addr=0x" << std::hex << kernel_arg.C_addr << std::endl; // generate source data + // Generate full matrix A (M × K), prune in-place, then compress to M × K/2 + std::vector h_A_full(sizeA_full); + for (uint32_t i = 0; i < sizeA_full; ++i) { + h_A_full[i] = generate_A_value(); + } + prune_fixed_mask(h_A_full.data(), M, K); std::vector h_A(sizeA); + compress_fixed_mask(h_A.data(), h_A_full.data(), M, K); + std::vector h_B(sizeB); - for (uint32_t i = 0; i < sizeA; ++i) { // assume it is pruned and compressed already - h_A[i] = generate_A_value(); - } for (uint32_t i = 0; i < sizeB; ++i) { h_B[i] = generate_B_value(); } @@ -899,7 +920,7 @@ int main(int argc, char *argv[]) { int errors = 0; { std::vector h_ref(sizeC); - matmul_cpu(h_ref.data(), h_A.data(), h_B.data(), M, N, K); + matmul_cpu(h_ref.data(), h_A_full.data(), h_B.data(), M, N, K); for (uint32_t i = 0; i < h_ref.size(); ++i) { if (!Comparator::compare(h_C[i], h_ref[i], i, errors)) { From 833bacf9dce00460b2f463dbc59179ca4594d5ac Mon Sep 17 00:00:00 2001 From: 
yanggon-kim Date: Sun, 8 Feb 2026 07:05:07 -0800 Subject: [PATCH 12/22] code minimization with same functionality --- hw/rtl/core/VX_decode.sv | 43 ++---- hw/rtl/tcu/VX_tcu_core.sv | 4 +- hw/rtl/tcu/VX_tcu_meta.sv | 7 +- hw/rtl/tcu/VX_tcu_pkg.sv | 10 +- hw/rtl/tcu/VX_tcu_sel.sv | 16 +- hw/rtl/tcu/VX_tcu_uops.sv | 2 +- kernel/include/vx_tensor.h | 139 +----------------- sim/common/tensor_cfg.h | 6 +- .../sgemm_tcu_struct_sparse/kernel.cpp | 42 +----- .../sgemm_tcu_struct_sparse/main.cpp | 10 +- .../sgemm_tcu_struct_sparse/sparse_test.py | 44 ------ 11 files changed, 39 insertions(+), 284 deletions(-) delete mode 100644 tests/regression/sgemm_tcu_struct_sparse/sparse_test.py diff --git a/hw/rtl/core/VX_decode.sv b/hw/rtl/core/VX_decode.sv index f83bd0d7e2..4a876a1464 100644 --- a/hw/rtl/core/VX_decode.sv +++ b/hw/rtl/core/VX_decode.sv @@ -548,35 +548,20 @@ module VX_decode import VX_gpu_pkg::*; #( end `ifdef EXT_TCU_ENABLE 7'h02: begin - case (funct3) - 3'h0: begin // WMMA_SYNC - ex_type = EX_TCU; - op_type = INST_OP_BITS'(INST_TCU_WMMA); - op_args.tcu.fmt_s = rs1[3:0]; - op_args.tcu.fmt_d = rd[3:0]; - op_args.tcu.step_m = '0; - op_args.tcu.step_n = '0; - op_args.tcu.step_k = '0; - `USED_FREG (rd); - `USED_FREG (rs1); - `USED_FREG (rs2); - `USED_FREG (rs3); - end - 3'h1: begin // WMMA_STRUCT_SPARSE_SYNC - ex_type = EX_TCU; - op_type = INST_OP_BITS'(INST_TCU_WMMA_SP); - op_args.tcu.fmt_s = rs1[3:0]; - op_args.tcu.fmt_d = rd[3:0]; - op_args.tcu.step_m = '0; - op_args.tcu.step_n = '0; - op_args.tcu.step_k = '0; - `USED_FREG (rd); - `USED_FREG (rs1); - `USED_FREG (rs2); - `USED_FREG (rs3); - end - default:; - endcase + if (funct3 == 3'h0 || funct3 == 3'h1) begin + ex_type = EX_TCU; + op_type = funct3[0] ? 
INST_OP_BITS'(INST_TCU_WMMA_SP) + : INST_OP_BITS'(INST_TCU_WMMA); + op_args.tcu.fmt_s = rs1[3:0]; + op_args.tcu.fmt_d = rd[3:0]; + op_args.tcu.step_m = '0; + op_args.tcu.step_n = '0; + op_args.tcu.step_k = '0; + `USED_FREG (rd); + `USED_FREG (rs1); + `USED_FREG (rs2); + `USED_FREG (rs3); + end end `endif default:; diff --git a/hw/rtl/tcu/VX_tcu_core.sv b/hw/rtl/tcu/VX_tcu_core.sv index 2e26295e24..f1d8be8635 100644 --- a/hw/rtl/tcu/VX_tcu_core.sv +++ b/hw/rtl/tcu/VX_tcu_core.sv @@ -123,7 +123,7 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( wire [TCU_TC_M-1:0][TCU_TC_N-1:0][31:0] d_val; - // Metadata block from VX_tcu_meta (for 2:4 sparsity) + // 2:4 sparsity metadata `ifndef TCU_ITYPE_BITS `define TCU_ITYPE_BITS 8 `endif @@ -149,9 +149,7 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( wire [TCU_TC_K-1:0][31:0] a_row, b_col, b_col_dense, b_col_sparse, b_col_1, b_col_2; for (genvar k_idx = 0; k_idx < TCU_TC_K; ++k_idx) begin : g_slice_assign assign a_row[k_idx] = 32'(execute_if.data.rs1_data[a_off + i * TCU_TC_K + k_idx]); - // Dense: B registers packed with TCU_TC_K per column assign b_col_dense[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K + k_idx]); - // Sparse: B registers packed with TCU_TC_K*2 per column (2x for sparsity) assign b_col_1[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2]); assign b_col_2[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2 + 1]); end diff --git a/hw/rtl/tcu/VX_tcu_meta.sv b/hw/rtl/tcu/VX_tcu_meta.sv index 104c3f6cf1..b8f83e5ee5 100644 --- a/hw/rtl/tcu/VX_tcu_meta.sv +++ b/hw/rtl/tcu/VX_tcu_meta.sv @@ -38,11 +38,10 @@ module VX_tcu_meta import VX_gpu_pkg::*, VX_tcu_pkg::*; #( localparam M_STEP_BITS = `CLOG2(TCU_M_STEPS); // Bits needed for step_m index localparam K_STEP_BITS = `CLOG2(HALF_K_STEPS); // Bits needed for step_k index (sparse) - // Read address calculation using bit concatenation (no multiplication) + // Read address: {step_m, 
step_k} wire [ADDRW-1:0] read_addr = {step_m[M_STEP_BITS-1:0], step_k[K_STEP_BITS-1:0]}; - // Post-reset initialization: write alternating patterns into SRAM - // addr LSB = step_k[0]: even → 0101 (positions 0,2), odd → 1010 (positions 1,3) + // Post-reset init: even addr → 0101, odd addr → 1010 reg [ADDRW:0] init_counter; wire init_active = ~init_counter[ADDRW]; wire [ADDRW-1:0] init_addr = init_counter[ADDRW-1:0]; @@ -58,7 +57,7 @@ module VX_tcu_meta import VX_gpu_pkg::*, VX_tcu_pkg::*; #( end end - // Metadata SRAM with combinational read + // Metadata SRAM (combinational read) VX_dp_ram #( .DATAW (META_BLOCK_WIDTH), .SIZE (DEPTH), diff --git a/hw/rtl/tcu/VX_tcu_pkg.sv b/hw/rtl/tcu/VX_tcu_pkg.sv index 081ed4e152..b119b06b0d 100644 --- a/hw/rtl/tcu/VX_tcu_pkg.sv +++ b/hw/rtl/tcu/VX_tcu_pkg.sv @@ -174,15 +174,9 @@ package VX_tcu_pkg; input op_args_t op_args ); case (INST_TCU_BITS'(op_type)) - INST_TCU_WMMA: begin - `TRACE(level, ("WMMA.")); - trace_fmt(level, op_args.tcu.fmt_s); - `TRACE(level, (".")); - trace_fmt(level, op_args.tcu.fmt_d); - `TRACE(level, (".%0d.%0d", op_args.tcu.step_m, op_args.tcu.step_n)); - end + INST_TCU_WMMA, INST_TCU_WMMA_SP: begin - `TRACE(level, ("WMMA_SP.")); + `TRACE(level, (INST_TCU_BITS'(op_type) == INST_TCU_WMMA_SP ? "WMMA_SP." 
: "WMMA.")); trace_fmt(level, op_args.tcu.fmt_s); `TRACE(level, (".")); trace_fmt(level, op_args.tcu.fmt_d); diff --git a/hw/rtl/tcu/VX_tcu_sel.sv b/hw/rtl/tcu/VX_tcu_sel.sv index 40c6df06f9..b3b6815d32 100644 --- a/hw/rtl/tcu/VX_tcu_sel.sv +++ b/hw/rtl/tcu/VX_tcu_sel.sv @@ -31,8 +31,7 @@ module VX_tcu_sel import VX_gpu_pkg::*, VX_tcu_pkg::*; #( for (genvar k = 0; k < TCU_TC_K; ++k) begin : g_bmux if (I_RATIO == 4) begin : g_ratio4 - // int8: b_col_1 and b_col_2 are separate 4-element groups - // Select 2 valid from each group -> 4 output elements (4x8=32 bits) + // int8: select 2 valid from each 4-element group wire [I_RATIO-1:0] grp_mask_lo = vld_meta_row[I_RATIO * k +: I_RATIO]; wire [I_RATIO-1:0] grp_mask_hi = vld_meta_row[I_RATIO * (TCU_TC_K + k) +: I_RATIO]; @@ -53,24 +52,21 @@ module VX_tcu_sel import VX_gpu_pkg::*, VX_tcu_pkg::*; #( assign b_col[k] = {hi_1, hi_0, lo_1, lo_0}; end else if (I_RATIO == 2) begin : g_ratio2 - // fp16: b_col_1 and b_col_2 together form ONE 4-element group - // Select 2 valid from the combined group -> 2 output elements (2x16=32 bits) + // fp16: select 2 valid from combined 4-element group wire [I_RATIO-1:0] mask_lo = vld_meta_row[I_RATIO * k +: I_RATIO]; wire [I_RATIO-1:0] mask_hi = vld_meta_row[I_RATIO * (TCU_TC_K + k) +: I_RATIO]; wire [3:0] grp_mask = {mask_hi, mask_lo}; - // Pool of 4 fp16 elements across 2 registers wire [ELT_W-1:0] elem0 = b_col_1[k][0 +: ELT_W]; wire [ELT_W-1:0] elem1 = b_col_1[k][ELT_W +: ELT_W]; wire [ELT_W-1:0] elem2 = b_col_2[k][0 +: ELT_W]; wire [ELT_W-1:0] elem3 = b_col_2[k][ELT_W +: ELT_W]; - // First valid (scan from LSB) + // First valid (LSB), last valid (MSB) wire [ELT_W-1:0] sel_0 = grp_mask[0] ? elem0 : grp_mask[1] ? elem1 : grp_mask[2] ? elem2 : elem3; - // Last valid (scan from MSB) wire [ELT_W-1:0] sel_1 = grp_mask[3] ? elem3 : grp_mask[2] ? elem2 : grp_mask[1] ? 
elem1 : elem0; @@ -78,7 +74,7 @@ module VX_tcu_sel import VX_gpu_pkg::*, VX_tcu_pkg::*; #( assign b_col[k] = {sel_1, sel_0}; end else if (I_RATIO == 8) begin : g_ratio8 - // int4: each 32-bit register has 8 elements in 2 sub-groups of 4 + // int4: 4 sub-groups of 4 nibbles each wire [I_RATIO-1:0] grp_mask_lo = vld_meta_row[I_RATIO * k +: I_RATIO]; wire [I_RATIO-1:0] grp_mask_hi = vld_meta_row[I_RATIO * (TCU_TC_K + k) +: I_RATIO]; wire [3:0] sg0_mask = grp_mask_lo[3:0]; @@ -86,28 +82,24 @@ module VX_tcu_sel import VX_gpu_pkg::*, VX_tcu_pkg::*; #( wire [3:0] sg2_mask = grp_mask_hi[3:0]; wire [3:0] sg3_mask = grp_mask_hi[7:4]; - // Sub-group 0: b_col_1 low half [elements 0-3] wire [ELT_W-1:0] sg0_0 = sg0_mask[0] ? b_col_1[k][0*ELT_W +: ELT_W] : sg0_mask[1] ? b_col_1[k][1*ELT_W +: ELT_W] : b_col_1[k][2*ELT_W +: ELT_W]; wire [ELT_W-1:0] sg0_1 = sg0_mask[3] ? b_col_1[k][3*ELT_W +: ELT_W] : sg0_mask[2] ? b_col_1[k][2*ELT_W +: ELT_W] : b_col_1[k][1*ELT_W +: ELT_W]; - // Sub-group 1: b_col_1 high half [elements 4-7] wire [ELT_W-1:0] sg1_0 = sg1_mask[0] ? b_col_1[k][4*ELT_W +: ELT_W] : sg1_mask[1] ? b_col_1[k][5*ELT_W +: ELT_W] : b_col_1[k][6*ELT_W +: ELT_W]; wire [ELT_W-1:0] sg1_1 = sg1_mask[3] ? b_col_1[k][7*ELT_W +: ELT_W] : sg1_mask[2] ? b_col_1[k][6*ELT_W +: ELT_W] : b_col_1[k][5*ELT_W +: ELT_W]; - // Sub-group 2: b_col_2 low half [elements 0-3] wire [ELT_W-1:0] sg2_0 = sg2_mask[0] ? b_col_2[k][0*ELT_W +: ELT_W] : sg2_mask[1] ? b_col_2[k][1*ELT_W +: ELT_W] : b_col_2[k][2*ELT_W +: ELT_W]; wire [ELT_W-1:0] sg2_1 = sg2_mask[3] ? b_col_2[k][3*ELT_W +: ELT_W] : sg2_mask[2] ? b_col_2[k][2*ELT_W +: ELT_W] : b_col_2[k][1*ELT_W +: ELT_W]; - // Sub-group 3: b_col_2 high half [elements 4-7] wire [ELT_W-1:0] sg3_0 = sg3_mask[0] ? b_col_2[k][4*ELT_W +: ELT_W] : sg3_mask[1] ? 
b_col_2[k][5*ELT_W +: ELT_W] : b_col_2[k][6*ELT_W +: ELT_W]; diff --git a/hw/rtl/tcu/VX_tcu_uops.sv b/hw/rtl/tcu/VX_tcu_uops.sv index 132727f76e..a738890777 100644 --- a/hw/rtl/tcu/VX_tcu_uops.sv +++ b/hw/rtl/tcu/VX_tcu_uops.sv @@ -75,7 +75,7 @@ module VX_tcu_uops import : ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB; wire [CTR_W-1:0] rs3_offset = (CTR_W'(m_index) << LG_N) | CTR_W'(n_index); - + // Register calculations wire [4:0] rs1 = TCU_RA + 5'(rs1_offset); wire [4:0] rs2 = TCU_RB + 5'(rs2_offset); diff --git a/kernel/include/vx_tensor.h b/kernel/include/vx_tensor.h index 058ee6da8f..002f3d8708 100644 --- a/kernel/include/vx_tensor.h +++ b/kernel/include/vx_tensor.h @@ -16,7 +16,6 @@ #include #include #include -#include // for vx_printf namespace vortex { namespace tensor { @@ -376,7 +375,7 @@ struct wmma_context { }); } - template + template static __attribute__((always_inline)) void mma_sync(FragD &fragD, const FragA &fragA, const FragB &fragB, const FragC &fragC) { static_assert(FragA::Use == matrix_a, "A must be matrix_a"); static_assert(FragB::Use == matrix_b, "B must be matrix_b"); @@ -424,25 +423,23 @@ struct wmma_context { register float fd6 __asm__("f30"); register float fd7 __asm__("f31"); - __asm__ volatile (".insn r %[insn], 0, 2, x%[fmd], x%[fms], x0" + constexpr int funct3 = sparse ? 
1 : 0; + __asm__ volatile (".insn r %[insn], %[f3], 2, x%[fmd], x%[fms], x0" : "=f"(fd0), "=f"(fd1), "=f"(fd2), "=f"(fd3), "=f"(fd4), "=f"(fd5), "=f"(fd6), "=f"(fd7) - : [insn]"i"(RISCV_CUSTOM0), [fmd]"i"(Ot::id), [fms]"i"(It::id), + : [insn]"i"(RISCV_CUSTOM0), [f3]"i"(funct3), [fmd]"i"(Ot::id), [fms]"i"(It::id), "f"(fa0), "f"(fa1), "f"(fa2), "f"(fa3), "f"(fa4), "f"(fa5), "f"(fa6), "f"(fa7), "f"(fb0), "f"(fb1), "f"(fb2), "f"(fb3), "f"(fb4), "f"(fb5), "f"(fb6), "f"(fb7), "f"(fc0), "f"(fc1), "f"(fc2), "f"(fc3), "f"(fc4), "f"(fc5), "f"(fc6), "f"(fc7) ); - // Write results to fragD fragD.data = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; } else { static_assert(FragB::NR == 4, "Unsupported number of registers for FragB"); - // fragB: caller-saved registers (f28-f31) register float fb0 __asm__("f28") = fragB.data[0]; register float fb1 __asm__("f29") = fragB.data[1]; register float fb2 __asm__("f30") = fragB.data[2]; register float fb3 __asm__("f31") = fragB.data[3]; - // fragC: mix of caller-saved (f10-f17) register float fc0 __asm__("f10") = fragC.data[0]; register float fc1 __asm__("f11") = fragC.data[1]; register float fc2 __asm__("f12") = fragC.data[2]; @@ -452,106 +449,6 @@ struct wmma_context { register float fc6 __asm__("f16") = fragC.data[6]; register float fc7 __asm__("f17") = fragC.data[7]; - // Force outputs into accumulator registers - register float fd0 __asm__("f10"); - register float fd1 __asm__("f11"); - register float fd2 __asm__("f12"); - register float fd3 __asm__("f13"); - register float fd4 __asm__("f14"); - register float fd5 __asm__("f15"); - register float fd6 __asm__("f16"); - register float fd7 __asm__("f17"); - - __asm__ volatile (".insn r %[insn], 0, 2, x%[fmd], x%[fms], x0" - : "=f"(fd0), "=f"(fd1), "=f"(fd2), "=f"(fd3), "=f"(fd4), "=f"(fd5), "=f"(fd6), "=f"(fd7) - : [insn]"i"(RISCV_CUSTOM0), [fmd]"i"(Ot::id), [fms]"i"(It::id), - "f"(fa0), "f"(fa1), "f"(fa2), "f"(fa3), "f"(fa4), "f"(fa5), "f"(fa6), "f"(fa7), - "f"(fb0), "f"(fb1), "f"(fb2), 
"f"(fb3), - "f"(fc0), "f"(fc1), "f"(fc2), "f"(fc3), "f"(fc4), "f"(fc5), "f"(fc6), "f"(fc7) - ); - - // Write results to fragD - fragD.data = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; - } - } - - template - static __attribute__((always_inline)) void mma_struct_sparse_sync(FragD &fragD, const FragA &fragA, const FragB &fragB, const FragC &fragC) { - static_assert(FragA::Use == matrix_a, "A must be matrix_a"); - static_assert(FragB::Use == matrix_b, "B must be matrix_b"); - static_assert(FragC::Use == accumulator, "C must be accumulator"); - static_assert(FragD::Use == accumulator, "D must be accumulator"); - - // fragA: caller-saved registers (f0-f7) - register float fa0 __asm__("f0") = fragA.data[0]; - register float fa1 __asm__("f1") = fragA.data[1]; - register float fa2 __asm__("f2") = fragA.data[2]; - register float fa3 __asm__("f3") = fragA.data[3]; - register float fa4 __asm__("f4") = fragA.data[4]; - register float fa5 __asm__("f5") = fragA.data[5]; - register float fa6 __asm__("f6") = fragA.data[6]; - register float fa7 __asm__("f7") = fragA.data[7]; - - if constexpr (FragB::NR == 8) { - // fragB: caller-saved registers (f10-f17) - register float fb0 __asm__("f10") = fragB.data[0]; - register float fb1 __asm__("f11") = fragB.data[1]; - register float fb2 __asm__("f12") = fragB.data[2]; - register float fb3 __asm__("f13") = fragB.data[3]; - register float fb4 __asm__("f14") = fragB.data[4]; - register float fb5 __asm__("f15") = fragB.data[5]; - register float fb6 __asm__("f16") = fragB.data[6]; - register float fb7 __asm__("f17") = fragB.data[7]; - - // fragC: mix of caller-saved (f28-f31) and callee-saved (f18-f21) - register float fc0 __asm__("f24") = fragC.data[0]; - register float fc1 __asm__("f25") = fragC.data[1]; - register float fc2 __asm__("f26") = fragC.data[2]; - register float fc3 __asm__("f27") = fragC.data[3]; - register float fc4 __asm__("f28") = fragC.data[4]; - register float fc5 __asm__("f29") = fragC.data[5]; - register float fc6 
__asm__("f30") = fragC.data[6]; - register float fc7 __asm__("f31") = fragC.data[7]; - - // Force outputs into accumulator registers - register float fd0 __asm__("f24"); - register float fd1 __asm__("f25"); - register float fd2 __asm__("f26"); - register float fd3 __asm__("f27"); - register float fd4 __asm__("f28"); - register float fd5 __asm__("f29"); - register float fd6 __asm__("f30"); - register float fd7 __asm__("f31"); - - __asm__ volatile (".insn r %[insn], 1, 2, x%[fmd], x%[fms], x0" - : "=f"(fd0), "=f"(fd1), "=f"(fd2), "=f"(fd3), "=f"(fd4), "=f"(fd5), "=f"(fd6), "=f"(fd7) - : [insn]"i"(RISCV_CUSTOM0), [fmd]"i"(Ot::id), [fms]"i"(It::id), - "f"(fa0), "f"(fa1), "f"(fa2), "f"(fa3), "f"(fa4), "f"(fa5), "f"(fa6), "f"(fa7), - "f"(fb0), "f"(fb1), "f"(fb2), "f"(fb3), "f"(fb4), "f"(fb5), "f"(fb6), "f"(fb7), - "f"(fc0), "f"(fc1), "f"(fc2), "f"(fc3), "f"(fc4), "f"(fc5), "f"(fc6), "f"(fc7) - ); - - // Write results to fragD - fragD.data = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; - } else { - static_assert(FragB::NR == 4, "Unsupported number of registers for FragB"); - // fragB: caller-saved registers (f28-f31) - register float fb0 __asm__("f28") = fragB.data[0]; - register float fb1 __asm__("f29") = fragB.data[1]; - register float fb2 __asm__("f30") = fragB.data[2]; - register float fb3 __asm__("f31") = fragB.data[3]; - - // fragC: mix of caller-saved (f10-f17) - register float fc0 __asm__("f10") = fragC.data[0]; - register float fc1 __asm__("f11") = fragC.data[1]; - register float fc2 __asm__("f12") = fragC.data[2]; - register float fc3 __asm__("f13") = fragC.data[3]; - register float fc4 __asm__("f14") = fragC.data[4]; - register float fc5 __asm__("f15") = fragC.data[5]; - register float fc6 __asm__("f16") = fragC.data[6]; - register float fc7 __asm__("f17") = fragC.data[7]; - - // Force outputs into accumulator registers register float fd0 __asm__("f10"); register float fd1 __asm__("f11"); register float fd2 __asm__("f12"); @@ -561,41 +458,19 @@ struct wmma_context 
{ register float fd6 __asm__("f16"); register float fd7 __asm__("f17"); - __asm__ volatile (".insn r %[insn], 1, 2, x%[fmd], x%[fms], x0" + constexpr int funct3 = sparse ? 1 : 0; + __asm__ volatile (".insn r %[insn], %[f3], 2, x%[fmd], x%[fms], x0" : "=f"(fd0), "=f"(fd1), "=f"(fd2), "=f"(fd3), "=f"(fd4), "=f"(fd5), "=f"(fd6), "=f"(fd7) - : [insn]"i"(RISCV_CUSTOM0), [fmd]"i"(Ot::id), [fms]"i"(It::id), + : [insn]"i"(RISCV_CUSTOM0), [f3]"i"(funct3), [fmd]"i"(Ot::id), [fms]"i"(It::id), "f"(fa0), "f"(fa1), "f"(fa2), "f"(fa3), "f"(fa4), "f"(fa5), "f"(fa6), "f"(fa7), "f"(fb0), "f"(fb1), "f"(fb2), "f"(fb3), "f"(fc0), "f"(fc1), "f"(fc2), "f"(fc3), "f"(fc4), "f"(fc5), "f"(fc6), "f"(fc7) ); - // Write results to fragD fragD.data = {fd0, fd1, fd2, fd3, fd4, fd5, fd6, fd7}; } } - template - static __attribute__((always_inline)) void mma_sp_sync( - FragD &fragD, - const FragA &fragA, - const FragB &fragB, - const FragC &fragC, - const FragMeta &fragMeta) { - - static_assert(FragA::Use == matrix_a, "A must be matrix_a"); - static_assert(FragB::Use == matrix_b, "B must be matrix_b"); - static_assert(FragC::Use == accumulator, "C must be accumulator"); - static_assert(FragD::Use == accumulator, "D must be accumulator"); - - // placeholder: sparsity path not implemented yet - (void)fragA; - (void)fragB; - (void)fragMeta; - - // NO-OP: keep accumulator unchanged so test will fail (mismatch) but completes cleanly - fragD.data = fragC.data; -} - }; } // namespace tensor diff --git a/sim/common/tensor_cfg.h b/sim/common/tensor_cfg.h index 1cee6946ce..0a193e7e0d 100644 --- a/sim/common/tensor_cfg.h +++ b/sim/common/tensor_cfg.h @@ -195,9 +195,9 @@ struct wmma_config_t { static constexpr uint32_t b_sub_blocks = block_cap / b_block_size; // number of B micro-tiles per register static constexpr uint32_t b_sub_steps = n_steps / b_sub_blocks; // number of B sub-steps per register - static constexpr uint32_t b_block_size_sp = (tcK * tcN) * 2; // size of B micro-tile (sparse 2:4) - static 
constexpr uint32_t b_sub_blocks_sp = block_cap / b_block_size_sp; // number of B micro-tiles per register (sparse) - static constexpr uint32_t b_sub_steps_sp = n_steps / b_sub_blocks_sp; // number of B sub-steps per register (sparse) + static constexpr uint32_t b_block_size_sp = (tcK * tcN) * 2; // sparse 2:4 + static constexpr uint32_t b_sub_blocks_sp = block_cap / b_block_size_sp; + static constexpr uint32_t b_sub_steps_sp = n_steps / b_sub_blocks_sp; static constexpr uint32_t NRA = (xtileM * xtileK) / NT; // Number of A registers static constexpr uint32_t NRB = (xtileN * xtileK) / NT; // Number of B registers diff --git a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp index 48a50799f4..dee684fad6 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp @@ -5,19 +5,6 @@ namespace vt = vortex::tensor; using ctx = vt::wmma_context; -// Decode fp16 bit pattern to value*100 (2 decimal places) using integer math -static inline int32_t fp16_to_x100(uint16_t h) { - uint32_t e = (h >> 10) & 0x1F; - uint32_t m = h & 0x3FF; - if (e == 0) return 0; // zero / subnormal → 0 - // val = 2^(e-15) * (1024+m) / 1024 - // val*100 = (1024+m)*100 * 2^(e-25) - int32_t v = (int32_t)(1024 + m) * 100; - int s = (int)e - 25; - v = (s >= 0) ? (v << s) : (v >> (-s)); - return (h & 0x8000) ? 
-v : v; -} - void kernel_body(kernel_arg_t *__UNIFORM__ arg) { auto pA = reinterpret_cast(arg->A_addr); auto pB = reinterpret_cast(arg->B_addr); @@ -31,23 +18,17 @@ void kernel_body(kernel_arg_t *__UNIFORM__ arg) { ctx::fragment_b fragB; ctx::fragment_acc fragC; - // calculate tile row & column based on block index uint32_t tile_row = blockIdx.y * ctx::tileM; uint32_t tile_col = blockIdx.x * ctx::tileN; - // Initialize accumulator tile to zero ctx::fill_fragment(fragC, 0); uint32_t stride_A = K / 2; for (int i = 0; i < (int)(K / 2); i += (int)(ctx::tileK / 2)) { auto pTileA = pA + tile_row * stride_A + i; - - // Load A tile (compressed: stride = K/2, sparse=true) ctx::load_matrix_sync(fragA, pTileA, stride_A); - // Load B tile (full: uses 2*i to index into original K dimension, sparse=true) if constexpr (vt::ITYPE::bits < 8) { - // For sub-byte matrix B must be in col-major format auto pTileB = pB + tile_col * K + (2 * i); ctx::load_matrix_sync(fragB, pTileB, K); } else { @@ -55,30 +36,9 @@ void kernel_body(kernel_arg_t *__UNIFORM__ arg) { ctx::load_matrix_sync(fragB, pTileB, N); } - // if (vx_thread_id() == 0 && blockIdx.x == 0 && blockIdx.y == 0) { - // for (uint32_t r = 0; r < 8; ++r) { - // uint32_t packed; - // asm volatile("fmv.x.w %0, %1" : "=r"(packed) : "f"(fragA.data[r])); - // int32_t lo = fp16_to_x100(packed & 0xFFFF); - // int32_t hi = fp16_to_x100((packed >> 16) & 0xFFFF); - // vx_printf("fragA[%d] | %d.%02d, %d.%02d\n", r, - // lo / 100, lo % 100, hi / 100, hi % 100); - // } - // for (uint32_t r = 0; r < 8; ++r) { - // uint32_t packed; - // asm volatile("fmv.x.w %0, %1" : "=r"(packed) : "f"(fragB.data[r])); - // int32_t lo = fp16_to_x100(packed & 0xFFFF); - // int32_t hi = fp16_to_x100((packed >> 16) & 0xFFFF); - // vx_printf("fragB[%d] | %d.%02d, %d.%02d\n", r, - // lo / 100, lo % 100, hi / 100, hi % 100); - // } - // } - - // Matrix multiply-accumulate: c += a * b (sparse instruction, funct3=1) - ctx::mma_struct_sparse_sync(fragC, fragA, fragB, 
fragC); + ctx::mma_sync(fragC, fragA, fragB, fragC); } - // Store the computed C tile auto pTileC = pC + tile_row * N + tile_col; ctx::store_matrix_sync(pTileC, fragC, N); } diff --git a/tests/regression/sgemm_tcu_struct_sparse/main.cpp b/tests/regression/sgemm_tcu_struct_sparse/main.cpp index 4c2757b5a8..9d426dc5a6 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/main.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/main.cpp @@ -641,8 +641,7 @@ using itype_t = typename vt::ITYPE::dtype; using otype_t = typename vt::OTYPE::dtype; -// Dense CPU reference matmul. Works for sparse case too because pruned A -// has zeros at masked positions, so A[m][k]*B[k][n] = 0 naturally. +// Dense CPU reference matmul (pruned A has zeros at masked positions) static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t M, uint32_t N, uint32_t K) { uint32_t subbytes = (vt::ITYPE::bits < 8) ? (8 / vt::ITYPE::bits) : 0; uint32_t KS = subbytes ? (K * subbytes) : K; @@ -659,8 +658,7 @@ static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t } } -// In-place: zero out elements in full A (M × K) that are NOT selected by -// the fixed 0101/1010 alternating metadata mask. +// Zero out elements in full A not selected by alternating 0101/1010 mask static void prune_fixed_mask(itype_t *A, uint32_t M, uint32_t K) { uint32_t subbytes = (vt::ITYPE::bits < 8) ? (8 / vt::ITYPE::bits) : 0; uint32_t KS = subbytes ? (K * subbytes) : K; @@ -681,8 +679,7 @@ static void prune_fixed_mask(itype_t *A, uint32_t M, uint32_t K) { } } -// Extract mask-selected (non-zero) positions from pruned A (M × K) into -// compressed output (M × K/2). +// Compress pruned A (M x K) to M x K/2 by extracting mask-selected positions static void compress_fixed_mask(itype_t *compressed, const itype_t *pruned_A, uint32_t M, uint32_t K) { uint32_t subbytes = (vt::ITYPE::bits < 8) ? 
(8 / vt::ITYPE::bits) : 0; @@ -768,7 +765,6 @@ void cleanup() { - int main(int argc, char *argv[]) { // parse command arguments parse_args(argc, argv); diff --git a/tests/regression/sgemm_tcu_struct_sparse/sparse_test.py b/tests/regression/sgemm_tcu_struct_sparse/sparse_test.py deleted file mode 100644 index fff759b077..0000000000 --- a/tests/regression/sgemm_tcu_struct_sparse/sparse_test.py +++ /dev/null @@ -1,44 +0,0 @@ -import numpy as np - -def prune_2_4_blockwise_with_mask(matrix): - """ - Perform 2:4 structured sparsity pruning on each row of the input matrix. - For each consecutive block of 4 elements, keep the two largest (by absolute value) - and zero out the rest. - Returns: - pruned: np.ndarray of same shape, with smaller elements zeroed out - mask: np.ndarray of bools, True where elements were kept - """ - pruned = matrix.copy() - mask = np.zeros_like(matrix, dtype=bool) - rows, cols = matrix.shape - - for i in range(rows): - for j in range(0, cols, 4): - block = pruned[i, j:j+4] - # Skip blocks that have fewer than 4 elements (at row end) - if block.shape[0] < 4: - continue - - abs_vals = np.abs(block) - sorted_idx = np.argsort(abs_vals) - top2_idx = sorted_idx[-2:] # Indices of the two largest absolute values - - block_mask = np.zeros_like(block, dtype=bool) - block_mask[top2_idx] = True - - #apply mask: zero out the smaller two elements in the block - pruned[i, j:j+4] = block * block_mask - mask[i, j:j+4] = block_mask - - return pruned, mask - -if __name__ == "__main__": - np.random.seed(42) - matrix = np.random.randn(8, 8) - - pruned_matrix, mask_matrix = prune_2_4_blockwise_with_mask(matrix) - - print("Original matrix:\n", matrix) - print("\nPruned matrix (2:4 structured sparse):\n", pruned_matrix) - print("\nMask matrix (True=kept, False=pruned):\n", mask_matrix) \ No newline at end of file From edd33613f7632bf13b93452262b6edab525ff3a1 Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Tue, 10 Feb 2026 14:49:33 -0800 Subject: [PATCH 13/22] loop code 
change --- tests/regression/sgemm_tcu/main.cpp | 2 +- tests/regression/sgemm_tcu_struct_sparse/kernel.cpp | 8 ++++---- tests/regression/sgemm_tcu_struct_sparse/main.cpp | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/regression/sgemm_tcu/main.cpp b/tests/regression/sgemm_tcu/main.cpp index 9726264703..21fec90f15 100644 --- a/tests/regression/sgemm_tcu/main.cpp +++ b/tests/regression/sgemm_tcu/main.cpp @@ -818,7 +818,7 @@ int main(int argc, char *argv[]) { // upload matrix B buffer { std::cout << "upload matrix B buffer" << std::endl; - if constexpr (std::is_same::value || + if constexpr (std::is_same::value || std::is_same::value || std::is_same::value) { // sub-byte matrix B must be in col-major format diff --git a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp index dee684fad6..3bf8436e0b 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp @@ -24,15 +24,15 @@ void kernel_body(kernel_arg_t *__UNIFORM__ arg) { ctx::fill_fragment(fragC, 0); uint32_t stride_A = K / 2; - for (int i = 0; i < (int)(K / 2); i += (int)(ctx::tileK / 2)) { - auto pTileA = pA + tile_row * stride_A + i; + for (int i = 0; i < (int)K; i += (int)ctx::tileK) { + auto pTileA = pA + tile_row * stride_A + (i / 2); ctx::load_matrix_sync(fragA, pTileA, stride_A); if constexpr (vt::ITYPE::bits < 8) { - auto pTileB = pB + tile_col * K + (2 * i); + auto pTileB = pB + tile_col * K + i; ctx::load_matrix_sync(fragB, pTileB, K); } else { - auto pTileB = pB + (2 * i) * N + tile_col; + auto pTileB = pB + i * N + tile_col; ctx::load_matrix_sync(fragB, pTileB, N); } diff --git a/tests/regression/sgemm_tcu_struct_sparse/main.cpp b/tests/regression/sgemm_tcu_struct_sparse/main.cpp index 9d426dc5a6..053a94f804 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/main.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/main.cpp @@ -871,7 +871,7 @@ int main(int 
argc, char *argv[]) { // upload matrix B buffer { std::cout << "upload matrix B buffer" << std::endl; - if constexpr (std::is_same::value || + if constexpr (std::is_same::value || std::is_same::value || std::is_same::value) { // sub-byte matrix B must be in col-major format From 1164bfe707ca980fdd6f614c1135b3cbb2428d97 Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Fri, 13 Feb 2026 11:50:57 -0800 Subject: [PATCH 14/22] NT=16 problem --- hw/rtl/tcu/VX_tcu_core.sv | 38 +++++++++++-- hw/rtl/tcu/VX_tcu_meta.sv | 11 ++-- hw/rtl/tcu/VX_tcu_pkg.sv | 3 +- hw/rtl/tcu/VX_tcu_uops.sv | 54 ++++++++++++------- kernel/include/vx_tensor.h | 4 +- sim/common/tensor_cfg.h | 5 +- .../sgemm_tcu_struct_sparse/main.cpp | 21 ++++++-- 7 files changed, 100 insertions(+), 36 deletions(-) diff --git a/hw/rtl/tcu/VX_tcu_core.sv b/hw/rtl/tcu/VX_tcu_core.sv index f1d8be8635..578df14929 100644 --- a/hw/rtl/tcu/VX_tcu_core.sv +++ b/hw/rtl/tcu/VX_tcu_core.sv @@ -78,6 +78,21 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( wire result_fire = result_if.valid && result_if.ready; wire fedp_enable, fedp_done; + // B_SPLIT: Phase 1 (step_k[0]=0) latches rs2, Phase 2 (step_k[0]=1) computes + wire b_split_phase1 = (TCU_B_SPLIT != 0) & is_sparse & ~step_k[0]; + + // B_SPLIT: per-warp latch for rs2_data (prevents cross-warp corruption) + if (TCU_B_SPLIT) begin : g_bsplit + reg [`NUM_WARPS-1:0][`NUM_TCU_LANES-1:0][`XLEN-1:0] rs2_data_latch; + wire [`LOG2UP(`NUM_WARPS)-1:0] bsplit_wid = execute_if.data.header.wid; + always @(posedge clk) begin + if (reset) + rs2_data_latch <= '0; + else if (execute_fire & b_split_phase1) + rs2_data_latch[bsplit_wid] <= execute_if.data.rs2_data; + end + end + // FEDP delay handling reg [PIPE_LATENCY-1:0] fedp_delay_pipe; always @(posedge clk) begin @@ -118,7 +133,9 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( wire [OFF_W-1:0] a_off = (OFF_W'(step_m) & OFF_W'(TCU_A_SUB_BLOCKS-1)) << LG_A_BS; wire [OFF_W-1:0] b_off = is_sparse - ? 
(OFF_W'(step_n) & OFF_W'(TCU_B_SUB_BLOCKS_SP-1)) << LG_B_BS_SP + ? (TCU_B_SPLIT + ? (OFF_W'(step_n) & OFF_W'(TCU_B_SUB_BLOCKS-1)) << LG_B_BS + : (OFF_W'(step_n) & OFF_W'(TCU_B_SUB_BLOCKS_SP-1)) << LG_B_BS_SP) : (OFF_W'(step_n) & OFF_W'(TCU_B_SUB_BLOCKS-1)) << LG_B_BS; wire [TCU_TC_M-1:0][TCU_TC_N-1:0][31:0] d_val; @@ -150,8 +167,20 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( for (genvar k_idx = 0; k_idx < TCU_TC_K; ++k_idx) begin : g_slice_assign assign a_row[k_idx] = 32'(execute_if.data.rs1_data[a_off + i * TCU_TC_K + k_idx]); assign b_col_dense[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K + k_idx]); - assign b_col_1[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2]); - assign b_col_2[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2 + 1]); + if (TCU_B_SPLIT) begin : g_bsplit_col + // B_SPLIT: pair adjacent lanes within same source (interleaved) + // First half of k uses Phase 1 latch, second half uses Phase 2 rs2 + if (k_idx < (TCU_TC_K / 2)) begin : g_phase1_lane + assign b_col_1[k_idx] = 32'(g_bsplit.rs2_data_latch[g_bsplit.bsplit_wid][b_off + j * TCU_TC_K + k_idx * 2]); + assign b_col_2[k_idx] = 32'(g_bsplit.rs2_data_latch[g_bsplit.bsplit_wid][b_off + j * TCU_TC_K + k_idx * 2 + 1]); + end else begin : g_phase2_lane + assign b_col_1[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K + (k_idx - TCU_TC_K/2) * 2]); + assign b_col_2[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K + (k_idx - TCU_TC_K/2) * 2 + 1]); + end + end else begin : g_std_col + assign b_col_1[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2]); + assign b_col_2[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2 + 1]); + end end wire [31:0] c_val = 32'(execute_if.data.rs3_data[i * TCU_TC_N + j]); @@ -171,7 +200,8 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( ); // Select dense or sparse B column - assign b_col = is_sparse ? 
b_col_sparse : b_col_dense; + // B_SPLIT Phase 1: zero b_col so FEDP computes 0+c=c (passthrough) + assign b_col = b_split_phase1 ? '0 : (is_sparse ? b_col_sparse : b_col_dense); wire [3:0] fmt_s_r, fmt_d_r; wire [TCU_TC_K-1:0][31:0] a_row_r, b_col_r; diff --git a/hw/rtl/tcu/VX_tcu_meta.sv b/hw/rtl/tcu/VX_tcu_meta.sv index b8f83e5ee5..95568ed14c 100644 --- a/hw/rtl/tcu/VX_tcu_meta.sv +++ b/hw/rtl/tcu/VX_tcu_meta.sv @@ -36,10 +36,15 @@ module VX_tcu_meta import VX_gpu_pkg::*, VX_tcu_pkg::*; #( localparam DEPTH = TCU_M_STEPS * HALF_K_STEPS; localparam ADDRW = `CLOG2(DEPTH); localparam M_STEP_BITS = `CLOG2(TCU_M_STEPS); // Bits needed for step_m index - localparam K_STEP_BITS = `CLOG2(HALF_K_STEPS); // Bits needed for step_k index (sparse) + localparam K_STEP_BITS = (HALF_K_STEPS > 1) ? `CLOG2(HALF_K_STEPS) : 0; - // Read address: {step_m, step_k} - wire [ADDRW-1:0] read_addr = {step_m[M_STEP_BITS-1:0], step_k[K_STEP_BITS-1:0]}; + // Read address: {step_m, step_k} — step_k omitted when K_STEP_BITS=0 (B_SPLIT) + wire [ADDRW-1:0] read_addr; + if (K_STEP_BITS > 0) begin : g_addr_mk + assign read_addr = {step_m[M_STEP_BITS-1:0], step_k[K_STEP_BITS-1:0]}; + end else begin : g_addr_m + assign read_addr = step_m[M_STEP_BITS-1:0]; + end // Post-reset init: even addr → 0101, odd addr → 1010 reg [ADDRW:0] init_counter; diff --git a/hw/rtl/tcu/VX_tcu_pkg.sv b/hw/rtl/tcu/VX_tcu_pkg.sv index b119b06b0d..091b7cb9d8 100644 --- a/hw/rtl/tcu/VX_tcu_pkg.sv +++ b/hw/rtl/tcu/VX_tcu_pkg.sv @@ -79,7 +79,8 @@ package VX_tcu_pkg; // B micro-tiling (sparse 2:4) localparam TCU_B_BLOCK_SIZE_SP = (TCU_TC_K * TCU_TC_N) * 2; - localparam TCU_B_SUB_BLOCKS_SP = TCU_BLOCK_CAP / TCU_B_BLOCK_SIZE_SP; + localparam TCU_B_SPLIT = (TCU_B_BLOCK_SIZE_SP > TCU_BLOCK_CAP); + localparam TCU_B_SUB_BLOCKS_SP = TCU_B_SPLIT ? 
1 : (TCU_BLOCK_CAP / TCU_B_BLOCK_SIZE_SP); // Register counts //localparam TCU_NRA = (TCU_TILE_M * TCU_TILE_K) / TCU_NT; diff --git a/hw/rtl/tcu/VX_tcu_uops.sv b/hw/rtl/tcu/VX_tcu_uops.sv index a738890777..c9f284c3c4 100644 --- a/hw/rtl/tcu/VX_tcu_uops.sv +++ b/hw/rtl/tcu/VX_tcu_uops.sv @@ -35,7 +35,7 @@ module VX_tcu_uops import localparam LG_A_SB = $clog2(TCU_A_SUB_BLOCKS); localparam LG_B_SB = $clog2(TCU_B_SUB_BLOCKS); - localparam LG_B_SB_SP = $clog2(TCU_B_SUB_BLOCKS_SP); + localparam LG_B_SB_SP = TCU_B_SPLIT ? 0 : $clog2(TCU_B_SUB_BLOCKS_SP); wire is_sparse_in = (ibuf_in.op_type == INST_TCU_WMMA_SP); reg is_sparse; @@ -47,31 +47,43 @@ module VX_tcu_uops import logic [`UP(LG_M)-1:0] m_index; logic [`UP(LG_K)-1:0] k_index; - if (LG_N != 0) begin : g_n_idx - assign n_index = counter[0 +: LG_N]; - end else begin : g_n_idx0 - assign n_index = 0; - end + if (TCU_B_SPLIT) begin : g_idx_bsplit + // B_SPLIT: when sparse, k iterates fastest so Phase1/Phase2 are consecutive + // when dense, use original order (n fastest) + assign k_index = is_sparse ? counter[0 +: LG_K] : counter[LG_N + LG_M +: LG_K]; + assign n_index = is_sparse ? counter[LG_K +: LG_N] : counter[0 +: LG_N]; + assign m_index = is_sparse ? 
counter[LG_K + LG_N +: LG_M] : counter[LG_N +: LG_M]; + end else begin : g_idx_normal + if (LG_N != 0) begin : g_n_idx + assign n_index = counter[0 +: LG_N]; + end else begin : g_n_idx0 + assign n_index = 0; + end - if (LG_M != 0) begin : g_m_idx - assign m_index = counter[LG_N +: LG_M]; - end else begin : g_m_idx0 - assign m_index = 0; - end + if (LG_M != 0) begin : g_m_idx + assign m_index = counter[LG_N +: LG_M]; + end else begin : g_m_idx0 + assign m_index = 0; + end - if (LG_K != 0) begin : g_k_idx - assign k_index = counter[LG_N + LG_M +: LG_K]; - end else begin : g_k_idx0 - assign k_index = 0; + if (LG_K != 0) begin : g_k_idx + assign k_index = counter[LG_N + LG_M +: LG_K]; + end else begin : g_k_idx0 + assign k_index = 0; + end end // Register offsets — dense vs sparse formulas wire [CTR_W-1:0] rs1_offset = is_sparse - ? ((CTR_W'(m_index) >> LG_A_SB) << (LG_K/2)) | CTR_W'(k_index) - : ((CTR_W'(m_index) >> LG_A_SB) << LG_K) | CTR_W'(k_index); + ? (TCU_B_SPLIT + ? (CTR_W'(m_index) >> LG_A_SB) + : ((CTR_W'(m_index) >> LG_A_SB) << (LG_K/2)) | CTR_W'(k_index)) + : ((CTR_W'(m_index) >> LG_A_SB) << LG_K) | CTR_W'(k_index); wire [CTR_W-1:0] rs2_offset = is_sparse - ? ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB_SP + ? (TCU_B_SPLIT + ? ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB + : ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB_SP) : ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB; wire [CTR_W-1:0] rs3_offset = (CTR_W'(m_index) << LG_N) | CTR_W'(n_index); @@ -126,10 +138,12 @@ module VX_tcu_uops import counter <= 0; busy <= 1; is_sparse <= is_sparse_in; - done <= is_sparse_in ? (TCU_UOPS/2 == 1) : (TCU_UOPS == 1); + done <= (is_sparse_in && !TCU_B_SPLIT) + ? (TCU_UOPS/2 == 1) + : (TCU_UOPS == 1); end else if (busy && next) begin counter <= counter + ((TCU_UOPS > 1) ? 1 : 0); - done <= is_sparse + done <= (is_sparse && !TCU_B_SPLIT) ? 
(counter == CTR_W'((TCU_UOPS/2)-2)) : (counter == CTR_W'(TCU_UOPS-2)); busy <= ~done; diff --git a/kernel/include/vx_tensor.h b/kernel/include/vx_tensor.h index 002f3d8708..08394b109c 100644 --- a/kernel/include/vx_tensor.h +++ b/kernel/include/vx_tensor.h @@ -247,8 +247,8 @@ struct wmma_context { }); } } else if constexpr (Frag::Use == matrix_b) { - if constexpr (sparse) { - // Sparse B load: uses 2x tcK for B block + if constexpr (sparse && !cfg::b_split) { + // Sparse B load (non-B_SPLIT): uses 2x tcK for B block constexpr uint32_t b_tcK = cfg::tcK * 2; uint32_t block_idx = (cfg::b_block_size_sp == NT) ? 0 : (lane / cfg::b_block_size_sp); uint32_t lane_in_blk = (cfg::b_block_size_sp == NT) ? lane : (lane % cfg::b_block_size_sp); diff --git a/sim/common/tensor_cfg.h b/sim/common/tensor_cfg.h index 0a193e7e0d..91f5d22fd4 100644 --- a/sim/common/tensor_cfg.h +++ b/sim/common/tensor_cfg.h @@ -196,8 +196,9 @@ struct wmma_config_t { static constexpr uint32_t b_sub_steps = n_steps / b_sub_blocks; // number of B sub-steps per register static constexpr uint32_t b_block_size_sp = (tcK * tcN) * 2; // sparse 2:4 - static constexpr uint32_t b_sub_blocks_sp = block_cap / b_block_size_sp; - static constexpr uint32_t b_sub_steps_sp = n_steps / b_sub_blocks_sp; + static constexpr bool b_split = (b_block_size_sp > NT); + static constexpr uint32_t b_sub_blocks_sp = b_split ? 1 : (block_cap / b_block_size_sp); + static constexpr uint32_t b_sub_steps_sp = b_split ? 
0 : (n_steps / b_sub_blocks_sp); static constexpr uint32_t NRA = (xtileM * xtileK) / NT; // Number of A registers static constexpr uint32_t NRB = (xtileN * xtileK) / NT; // Number of B registers diff --git a/tests/regression/sgemm_tcu_struct_sparse/main.cpp b/tests/regression/sgemm_tcu_struct_sparse/main.cpp index 053a94f804..9d41ad4625 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/main.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/main.cpp @@ -668,8 +668,15 @@ static void prune_fixed_mask(itype_t *A, uint32_t M, uint32_t K) { for (uint32_t m = 0; m < M; ++m) { for (uint32_t k1 = 0; k1 < (KS / 4); ++k1) { uint32_t k_start = k1 * 4; - uint32_t pos_in_tile = k_start % tile_k_elem; - uint8_t meta_mask = (pos_in_tile < half_tile) ? 0b0101 : 0b1010; + uint8_t meta_mask; + if constexpr (cfg::b_split) { + // B_SPLIT: metadata addresses by step_m (alternating by row group) + uint32_t step_m = (m % cfg::tileM) / cfg::tcM; + meta_mask = (step_m & 1) ? 0b1010 : 0b0101; + } else { + uint32_t pos_in_tile = k_start % tile_k_elem; + meta_mask = (pos_in_tile < half_tile) ? 0b0101 : 0b1010; + } for (uint32_t k2 = 0; k2 < 4; ++k2) { if (!(meta_mask & (1 << k2))) { data_accessor_t::write(A, m * KS + k_start + k2, 0); @@ -692,8 +699,14 @@ static void compress_fixed_mask(itype_t *compressed, const itype_t *pruned_A, uint32_t a_out = 0; for (uint32_t k1 = 0; k1 < (KS / 4); ++k1) { uint32_t k_start = k1 * 4; - uint32_t pos_in_tile = k_start % tile_k_elem; - uint8_t meta_mask = (pos_in_tile < half_tile) ? 0b0101 : 0b1010; + uint8_t meta_mask; + if constexpr (cfg::b_split) { + uint32_t step_m = (m % cfg::tileM) / cfg::tcM; + meta_mask = (step_m & 1) ? 0b1010 : 0b0101; + } else { + uint32_t pos_in_tile = k_start % tile_k_elem; + meta_mask = (pos_in_tile < half_tile) ? 
0b0101 : 0b1010; + } for (uint32_t k2 = 0; k2 < 4; ++k2) { if (meta_mask & (1 << k2)) { auto val = data_accessor_t::read(pruned_A, m * KS + k_start + k2); From 6a203aca7a8f0d0f8dd0e032f34c17acb2395931 Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Sun, 15 Feb 2026 22:23:15 -0800 Subject: [PATCH 15/22] meta_store new SRAM feeding instruction --- hw/rtl/VX_gpu_pkg.sv | 5 +- hw/rtl/core/VX_decode.sv | 9 +++ hw/rtl/tcu/VX_tcu_core.sv | 32 ++++++++- hw/rtl/tcu/VX_tcu_meta.sv | 71 ++++++++++--------- hw/rtl/tcu/VX_tcu_pkg.sv | 3 + kernel/include/vx_tensor.h | 17 +++++ sim/common/tensor_cfg.h | 7 ++ .../sgemm_tcu_struct_sparse/common.h | 1 + .../sgemm_tcu_struct_sparse/kernel.cpp | 5 ++ .../sgemm_tcu_struct_sparse/main.cpp | 23 ++++++ 10 files changed, 135 insertions(+), 38 deletions(-) diff --git a/hw/rtl/VX_gpu_pkg.sv b/hw/rtl/VX_gpu_pkg.sv index 641054f119..068041a336 100644 --- a/hw/rtl/VX_gpu_pkg.sv +++ b/hw/rtl/VX_gpu_pkg.sv @@ -460,8 +460,9 @@ package VX_gpu_pkg; `ifdef EXT_TCU_ENABLE - localparam INST_TCU_WMMA = 4'h0; - localparam INST_TCU_WMMA_SP = 4'h1; + localparam INST_TCU_WMMA = 4'h0; + localparam INST_TCU_WMMA_SP = 4'h1; + localparam INST_TCU_META_STORE = 4'h2; localparam INST_TCU_BITS = 4; `endif diff --git a/hw/rtl/core/VX_decode.sv b/hw/rtl/core/VX_decode.sv index 4a876a1464..16e286a658 100644 --- a/hw/rtl/core/VX_decode.sv +++ b/hw/rtl/core/VX_decode.sv @@ -561,6 +561,15 @@ module VX_decode import VX_gpu_pkg::*; #( `USED_FREG (rs1); `USED_FREG (rs2); `USED_FREG (rs3); + end else if (funct3 == 3'h2) begin + ex_type = EX_TCU; + op_type = INST_OP_BITS'(INST_TCU_META_STORE); + op_args.tcu.fmt_d = rd[3:0]; // col_idx + op_args.tcu.fmt_s = '0; + op_args.tcu.step_m = '0; + op_args.tcu.step_n = '0; + op_args.tcu.step_k = '0; + `USED_FREG (rs1); // source float register end end `endif diff --git a/hw/rtl/tcu/VX_tcu_core.sv b/hw/rtl/tcu/VX_tcu_core.sv index 578df14929..7dc6db7457 100644 --- a/hw/rtl/tcu/VX_tcu_core.sv +++ b/hw/rtl/tcu/VX_tcu_core.sv @@ -62,6 
+62,7 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( localparam OFF_W = $clog2(TCU_BLOCK_CAP); wire is_sparse = (execute_if.data.op_type == INST_TCU_WMMA_SP); + wire is_meta_store = (execute_if.data.op_type == INST_TCU_META_STORE); wire [3:0] step_m = execute_if.data.op_args.tcu.step_m; wire [3:0] step_n = execute_if.data.op_args.tcu.step_n; @@ -70,6 +71,25 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( wire [3:0] fmt_s = execute_if.data.op_args.tcu.fmt_s; wire [3:0] fmt_d = execute_if.data.op_args.tcu.fmt_d; + wire [`LOG2UP(`NUM_WARPS)-1:0] wid = execute_if.data.header.wid; + + // meta_store: extract per-row write data from rs1_data lanes + localparam PER_WARP_DEPTH = TCU_M_STEPS * (TCU_K_STEPS / 2); + wire meta_wr_en = execute_fire && is_meta_store; + wire [PER_WARP_DEPTH-1:0][31:0] meta_wr_data; + for (genvar r = 0; r < PER_WARP_DEPTH; ++r) begin : g_meta_wr + assign meta_wr_data[r] = 32'(execute_if.data.rs1_data[r]); + end + + // meta_store: force rd=0 in mdata_queue header (x0 write is harmless) + tcu_header_t mdata_queue_in; + always_comb begin + mdata_queue_in = execute_if.data.header; + if (is_meta_store) begin + mdata_queue_in.rd = '0; + end + end + `UNUSED_VAR ({step_m, step_n, step_k, fmt_s, fmt_d, execute_if.data}); wire mdata_queue_full; @@ -122,7 +142,7 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( .reset (reset), .push (execute_fire), .pop (result_fire), - .data_in(execute_if.data.header), + .data_in(mdata_queue_in), .data_out(result_if.data.header), `UNUSED_PIN(empty), `UNUSED_PIN(alm_empty), @@ -152,13 +172,19 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( VX_tcu_meta #( .INSTANCE_ID (INSTANCE_ID), - .META_BLOCK_WIDTH(META_BLOCK_WIDTH) + .META_BLOCK_WIDTH(META_BLOCK_WIDTH), + .PER_WARP_DEPTH (PER_WARP_DEPTH) ) tcu_meta ( .clk (clk), .reset (reset), + .raddr_wid (wid), .step_m (step_m), .step_k (step_k), - .vld_meta_block(vld_meta_block) + .vld_meta_block(vld_meta_block), + .wr_en 
(meta_wr_en), + .wr_wid (wid), + .wr_col_idx (fmt_d), + .wr_data (meta_wr_data) ); for (genvar i = 0; i < TCU_TC_M; ++i) begin : g_i diff --git a/hw/rtl/tcu/VX_tcu_meta.sv b/hw/rtl/tcu/VX_tcu_meta.sv index 95568ed14c..ccd6d534a5 100644 --- a/hw/rtl/tcu/VX_tcu_meta.sv +++ b/hw/rtl/tcu/VX_tcu_meta.sv @@ -17,36 +17,51 @@ module VX_tcu_meta import VX_gpu_pkg::*, VX_tcu_pkg::*; #( parameter `STRING INSTANCE_ID = "", - parameter META_BLOCK_WIDTH = 64 // Default: TCU_NT * 2 * I_RATIO + parameter META_BLOCK_WIDTH = 64, + parameter PER_WARP_DEPTH = 4 ) ( input wire clk, input wire reset, - // Step indices (from VX_tcu_core) + // Read port (from FEDP path) + input wire [`LOG2UP(`NUM_WARPS)-1:0] raddr_wid, input wire [3:0] step_m, input wire [3:0] step_k, + output wire [META_BLOCK_WIDTH-1:0] vld_meta_block, - // Output (combinational) - output wire [META_BLOCK_WIDTH-1:0] vld_meta_block + // Write port (meta_store instruction) + input wire wr_en, + input wire [`LOG2UP(`NUM_WARPS)-1:0] wr_wid, + input wire [3:0] wr_col_idx, + input wire [PER_WARP_DEPTH-1:0][31:0] wr_data ); `UNUSED_SPARAM (INSTANCE_ID) // Local parameters localparam HALF_K_STEPS = TCU_K_STEPS / 2; - localparam DEPTH = TCU_M_STEPS * HALF_K_STEPS; - localparam ADDRW = `CLOG2(DEPTH); - localparam M_STEP_BITS = `CLOG2(TCU_M_STEPS); // Bits needed for step_m index - localparam K_STEP_BITS = (HALF_K_STEPS > 1) ? `CLOG2(HALF_K_STEPS) : 0; + localparam TOTAL_DEPTH = `NUM_WARPS * PER_WARP_DEPTH; + localparam ADDRW = `CLOG2(TOTAL_DEPTH); + localparam ADDRW_PW = `CLOG2(PER_WARP_DEPTH); + localparam M_STEP_BITS = `CLOG2(TCU_M_STEPS); + localparam K_STEP_BITS = (HALF_K_STEPS > 1) ? 
`CLOG2(HALF_K_STEPS) : 0; + localparam NUM_COLS = META_BLOCK_WIDTH / 32; - // Read address: {step_m, step_k} — step_k omitted when K_STEP_BITS=0 (B_SPLIT) - wire [ADDRW-1:0] read_addr; + // Metadata register array (per-warp partitioned) + reg [META_BLOCK_WIDTH-1:0] meta_mem [0:TOTAL_DEPTH-1]; + + // Read address: {wid, step_m, step_k} + wire [ADDRW_PW-1:0] per_warp_raddr; if (K_STEP_BITS > 0) begin : g_addr_mk - assign read_addr = {step_m[M_STEP_BITS-1:0], step_k[K_STEP_BITS-1:0]}; + assign per_warp_raddr = {step_m[M_STEP_BITS-1:0], step_k[K_STEP_BITS-1:0]}; end else begin : g_addr_m - assign read_addr = step_m[M_STEP_BITS-1:0]; + assign per_warp_raddr = step_m[M_STEP_BITS-1:0]; end + wire [ADDRW-1:0] read_addr = {raddr_wid, per_warp_raddr}; + + // Combinational read + assign vld_meta_block = meta_mem[read_addr]; - // Post-reset init: even addr → 0101, odd addr → 1010 + // Post-reset init counter: fills all warps with alternating patterns reg [ADDRW:0] init_counter; wire init_active = ~init_counter[ADDRW]; wire [ADDRW-1:0] init_addr = init_counter[ADDRW-1:0]; @@ -54,34 +69,24 @@ module VX_tcu_meta import VX_gpu_pkg::*, VX_tcu_pkg::*; #( {(META_BLOCK_WIDTH/4){4'b1010}} : {(META_BLOCK_WIDTH/4){4'b0101}}; + // Write logic: init or runtime meta_store always_ff @(posedge clk) begin if (reset) begin init_counter <= 0; end else if (init_active) begin + meta_mem[init_addr] <= init_data; init_counter <= init_counter + 1; + end else if (wr_en) begin + for (int row = 0; row < PER_WARP_DEPTH; row++) begin + for (int col = 0; col < NUM_COLS; col++) begin + if (col == int'(wr_col_idx)) begin + meta_mem[{wr_wid, ADDRW_PW'(row)}][col*32 +: 32] <= wr_data[row]; + end + end + end end end - // Metadata SRAM (combinational read) - VX_dp_ram #( - .DATAW (META_BLOCK_WIDTH), - .SIZE (DEPTH), - .WRENW (1), - .OUT_REG (0), // Combinational read: output same cycle as address - .RDW_MODE ("R"), - .INIT_ENABLE (0) - ) meta_store ( - .clk (clk), - .reset (1'b0), - .read (1'b1), // Always 
enabled (combinational) - .write (init_active), - .wren (1'b1), - .waddr (init_addr), - .wdata (init_data), - .raddr (read_addr), - .rdata (vld_meta_block) - ); - endmodule /* verilator lint_on UNUSEDSIGNAL */ diff --git a/hw/rtl/tcu/VX_tcu_pkg.sv b/hw/rtl/tcu/VX_tcu_pkg.sv index 091b7cb9d8..e11e6f2777 100644 --- a/hw/rtl/tcu/VX_tcu_pkg.sv +++ b/hw/rtl/tcu/VX_tcu_pkg.sv @@ -183,6 +183,9 @@ package VX_tcu_pkg; trace_fmt(level, op_args.tcu.fmt_d); `TRACE(level, (".%0d.%0d", op_args.tcu.step_m, op_args.tcu.step_n)); end + INST_TCU_META_STORE: begin + `TRACE(level, ("META_STORE.col%0d", op_args.tcu.fmt_d)); + end default: `TRACE(level, ("?")) endcase endtask diff --git a/kernel/include/vx_tensor.h b/kernel/include/vx_tensor.h index 08394b109c..69cb46f55f 100644 --- a/kernel/include/vx_tensor.h +++ b/kernel/include/vx_tensor.h @@ -375,6 +375,23 @@ struct wmma_context { }); } + template + static __attribute__((always_inline)) void meta_store(float data) { + __asm__ volatile(".insn r 0x0b, 2, 2, x%[col], %[data], x0" + :: [col]"i"(COL), [data]"f"(data)); + } + + static __attribute__((always_inline)) void load_metadata_sync(const void* meta_ptr) { + constexpr uint32_t rtl_i_ratio = 32 / It::bits; + constexpr uint32_t num_cols = (NT * 2 * rtl_i_ratio) / 32; + uint32_t lane_id = vx_thread_id(); + auto base = reinterpret_cast(meta_ptr); + detail::unroll_for([&](auto col) { + float data = base[lane_id * num_cols + col]; + meta_store(data); + }); + } + template static __attribute__((always_inline)) void mma_sync(FragD &fragD, const FragA &fragA, const FragB &fragB, const FragC &fragC) { static_assert(FragA::Use == matrix_a, "A must be matrix_a"); diff --git a/sim/common/tensor_cfg.h b/sim/common/tensor_cfg.h index 91f5d22fd4..ddf3b4e4d2 100644 --- a/sim/common/tensor_cfg.h +++ b/sim/common/tensor_cfg.h @@ -222,6 +222,13 @@ struct wmma_config_t { static constexpr uint32_t tileM = xtileM; static constexpr uint32_t tileN = xtileN; static constexpr uint32_t tileK = xtileK * 
i_ratio; // Adjusted for input type size + + // Metadata constants for 2:4 structured sparsity + static constexpr uint32_t itype_bits = It::bits; + static constexpr uint32_t rtl_i_ratio = 32 / itype_bits; + static constexpr uint32_t meta_block_width = NT * 2 * rtl_i_ratio; // bits + static constexpr uint32_t meta_cols = meta_block_width / 32; + static constexpr uint32_t per_warp_depth = m_steps * (k_steps / 2); }; } // namespace tensor diff --git a/tests/regression/sgemm_tcu_struct_sparse/common.h b/tests/regression/sgemm_tcu_struct_sparse/common.h index a762a4fb2e..478fe9b733 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/common.h +++ b/tests/regression/sgemm_tcu_struct_sparse/common.h @@ -22,6 +22,7 @@ typedef struct { uint64_t A_addr; uint64_t B_addr; uint64_t C_addr; + uint64_t meta_addr; } kernel_arg_t; #endif diff --git a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp index 3bf8436e0b..21d79202b1 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp @@ -9,11 +9,16 @@ void kernel_body(kernel_arg_t *__UNIFORM__ arg) { auto pA = reinterpret_cast(arg->A_addr); auto pB = reinterpret_cast(arg->B_addr); auto pC = reinterpret_cast(arg->C_addr); + auto pMeta = reinterpret_cast(arg->meta_addr); uint32_t M = arg->M; uint32_t N = arg->N; uint32_t K = arg->K; + // Phase 1: Load metadata into SRAM (once per tile) + ctx::load_metadata_sync(pMeta); + + // Phase 2: Compute ctx::fragment_a fragA; ctx::fragment_b fragB; ctx::fragment_acc fragC; diff --git a/tests/regression/sgemm_tcu_struct_sparse/main.cpp b/tests/regression/sgemm_tcu_struct_sparse/main.cpp index 9d41ad4625..3491d06122 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/main.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/main.cpp @@ -730,6 +730,7 @@ vx_device_h device = nullptr; vx_buffer_h A_buffer = nullptr; vx_buffer_h B_buffer = nullptr; vx_buffer_h C_buffer = nullptr; 
+vx_buffer_h meta_buffer = nullptr; vx_buffer_h krnl_buffer = nullptr; vx_buffer_h args_buffer = nullptr; kernel_arg_t kernel_arg = {}; @@ -770,6 +771,7 @@ void cleanup() { vx_mem_free(A_buffer); vx_mem_free(B_buffer); vx_mem_free(C_buffer); + vx_mem_free(meta_buffer); vx_mem_free(krnl_buffer); vx_mem_free(args_buffer); vx_dev_close(device); @@ -856,9 +858,17 @@ int main(int argc, char *argv[]) { RT_CHECK(vx_mem_alloc(device, sizeC * sizeof(otype_t), VX_MEM_WRITE, &C_buffer)); RT_CHECK(vx_mem_address(C_buffer, &kernel_arg.C_addr)); + // allocate metadata buffer (padded to NT rows for all lanes) + constexpr uint32_t meta_cols = cfg::meta_cols; + constexpr uint32_t per_warp_depth = cfg::per_warp_depth; + constexpr uint32_t meta_buf_entries = NUM_THREADS * meta_cols; + RT_CHECK(vx_mem_alloc(device, meta_buf_entries * sizeof(uint32_t), VX_MEM_READ, &meta_buffer)); + RT_CHECK(vx_mem_address(meta_buffer, &kernel_arg.meta_addr)); + std::cout << "A_addr=0x" << std::hex << kernel_arg.A_addr << std::endl; std::cout << "B_addr=0x" << std::hex << kernel_arg.B_addr << std::endl; std::cout << "C_addr=0x" << std::hex << kernel_arg.C_addr << std::endl; + std::cout << "meta_addr=0x" << std::hex << kernel_arg.meta_addr << std::endl; // generate source data // Generate full matrix A (M × K), prune in-place, then compress to M × K/2 @@ -897,6 +907,19 @@ int main(int argc, char *argv[]) { } } + // upload metadata buffer + { + std::cout << "upload metadata buffer" << std::endl; + std::vector h_meta(meta_buf_entries, 0); + for (uint32_t row = 0; row < per_warp_depth; ++row) { + uint32_t pattern = (row & 1) ? 
0xAAAAAAAA : 0x55555555; + for (uint32_t col = 0; col < meta_cols; ++col) { + h_meta[row * meta_cols + col] = pattern; + } + } + RT_CHECK(vx_copy_to_dev(meta_buffer, h_meta.data(), 0, meta_buf_entries * sizeof(uint32_t))); + } + // upload program std::cout << "upload program" << std::endl; RT_CHECK(vx_upload_kernel_file(device, kernel_file, &krnl_buffer)); From 5bc74fb1e8abc3d81516a0e221bb4c3dd04f171c Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Mon, 16 Feb 2026 07:35:48 -0800 Subject: [PATCH 16/22] real meta, dynamic meta generation and run --- .../sgemm_tcu_struct_sparse/kernel.cpp | 18 +- .../sgemm_tcu_struct_sparse/main.cpp | 328 +++++++++++++++--- 2 files changed, 295 insertions(+), 51 deletions(-) diff --git a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp index 21d79202b1..3fdec5061e 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp @@ -9,16 +9,12 @@ void kernel_body(kernel_arg_t *__UNIFORM__ arg) { auto pA = reinterpret_cast(arg->A_addr); auto pB = reinterpret_cast(arg->B_addr); auto pC = reinterpret_cast(arg->C_addr); - auto pMeta = reinterpret_cast(arg->meta_addr); + auto pMetaBase = reinterpret_cast(arg->meta_addr); uint32_t M = arg->M; uint32_t N = arg->N; uint32_t K = arg->K; - // Phase 1: Load metadata into SRAM (once per tile) - ctx::load_metadata_sync(pMeta); - - // Phase 2: Compute ctx::fragment_a fragA; ctx::fragment_b fragB; ctx::fragment_acc fragC; @@ -28,8 +24,20 @@ void kernel_body(kernel_arg_t *__UNIFORM__ arg) { ctx::fill_fragment(fragC, 0); + // Per-K-tile metadata reload + constexpr uint32_t rtl_i_ratio = 32 / vt::ITYPE::bits; + constexpr uint32_t meta_cols = (NUM_THREADS * 2 * rtl_i_ratio) / 32; + constexpr uint32_t per_k_tile_words = NUM_THREADS * meta_cols; + uint32_t num_k_tiles = K / ctx::tileK; + uint32_t tile_row_idx = blockIdx.y; + uint32_t stride_A = K / 2; for (int i = 0; i < (int)K; i += 
(int)ctx::tileK) { + // Load metadata for this K-tile + uint32_t k_tile = i / ctx::tileK; + auto pMeta = pMetaBase + (tile_row_idx * num_k_tiles + k_tile) * per_k_tile_words; + ctx::load_metadata_sync(pMeta); + auto pTileA = pA + tile_row * stride_A + (i / 2); ctx::load_matrix_sync(fragA, pTileA, stride_A); diff --git a/tests/regression/sgemm_tcu_struct_sparse/main.cpp b/tests/regression/sgemm_tcu_struct_sparse/main.cpp index 3491d06122..398a0588a4 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/main.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/main.cpp @@ -658,57 +658,85 @@ static void matmul_cpu(otype_t *C, const itype_t *A, const itype_t *B, uint32_t } } -// Zero out elements in full A not selected by alternating 0101/1010 mask -static void prune_fixed_mask(itype_t *A, uint32_t M, uint32_t K) { +// Get magnitude of element at given offset in A matrix (for pruning comparison) +static float get_element_magnitude(const itype_t *A, uint32_t offset) { + auto val = data_accessor_t::read(A, offset); + if constexpr (std::is_same_v || std::is_same_v) { + return std::abs(static_cast(static_cast(val))); + } else if constexpr (std::is_same_v) { + return static_cast(val); + } else if constexpr (std::is_same_v) { + int32_t sval = val & 0xF; + if (sval & 0x8) sval |= ~0xF; + return std::abs(static_cast(sval)); + } else if constexpr (std::is_same_v) { + return static_cast(val & 0xF); + } else if constexpr (std::is_same_v) { + return std::abs(bit_cast(rv_htof_s(val, 0, nullptr))); + } else if constexpr (std::is_same_v) { + return std::abs(bit_cast(rv_btof_s(val, 0, nullptr))); + } else { + return std::abs(static_cast(val)); + } +} + +// Prune matrix A with real 2:4 structured sparsity (top-2 by magnitude per group of 4) +// Zeros pruned elements in-place and stores per-group 4-bit masks +static void prune_2to4(itype_t *A, std::vector &masks, uint32_t M, uint32_t K) { uint32_t subbytes = (vt::ITYPE::bits < 8) ? (8 / vt::ITYPE::bits) : 0; uint32_t KS = subbytes ? 
(K * subbytes) : K; - uint32_t tile_k_elem = subbytes ? (cfg::tileK * subbytes) : cfg::tileK; - uint32_t half_tile = tile_k_elem / 2; + uint32_t num_groups = KS / 4; + masks.resize(M * num_groups); for (uint32_t m = 0; m < M; ++m) { - for (uint32_t k1 = 0; k1 < (KS / 4); ++k1) { - uint32_t k_start = k1 * 4; - uint8_t meta_mask; - if constexpr (cfg::b_split) { - // B_SPLIT: metadata addresses by step_m (alternating by row group) - uint32_t step_m = (m % cfg::tileM) / cfg::tcM; - meta_mask = (step_m & 1) ? 0b1010 : 0b0101; - } else { - uint32_t pos_in_tile = k_start % tile_k_elem; - meta_mask = (pos_in_tile < half_tile) ? 0b0101 : 0b1010; + for (uint32_t g = 0; g < num_groups; ++g) { + uint32_t k_start = g * 4; + + // Get magnitudes + float mags[4]; + for (int p = 0; p < 4; ++p) { + mags[p] = get_element_magnitude(A, m * KS + k_start + p); } - for (uint32_t k2 = 0; k2 < 4; ++k2) { - if (!(meta_mask & (1 << k2))) { - data_accessor_t::write(A, m * KS + k_start + k2, 0); + + // Find indices of top-2 by magnitude (ties broken by lower index) + int top[2] = {0, 1}; + if (mags[1] > mags[0]) { top[0] = 1; top[1] = 0; } + for (int p = 2; p < 4; ++p) { + if (mags[p] > mags[top[0]]) { + top[1] = top[0]; + top[0] = p; + } else if (mags[p] > mags[top[1]]) { + top[1] = p; + } + } + + // Build mask and zero pruned elements + uint8_t mask = (1 << top[0]) | (1 << top[1]); + masks[m * num_groups + g] = mask; + for (int p = 0; p < 4; ++p) { + if (!(mask & (1 << p))) { + data_accessor_t::write(A, m * KS + k_start + p, 0); } } } } } -// Compress pruned A (M x K) to M x K/2 by extracting mask-selected positions -static void compress_fixed_mask(itype_t *compressed, const itype_t *pruned_A, - uint32_t M, uint32_t K) { +// Compress pruned A (M x K) to M x K/2 using per-group masks +static void compress_2to4(itype_t *compressed, const itype_t *pruned_A, + const std::vector &masks, uint32_t M, uint32_t K) { uint32_t subbytes = (vt::ITYPE::bits < 8) ? 
(8 / vt::ITYPE::bits) : 0; uint32_t KS = subbytes ? (K * subbytes) : K; uint32_t stride_comp = KS / 2; - uint32_t tile_k_elem = subbytes ? (cfg::tileK * subbytes) : cfg::tileK; - uint32_t half_tile = tile_k_elem / 2; + uint32_t num_groups = KS / 4; for (uint32_t m = 0; m < M; ++m) { uint32_t a_out = 0; - for (uint32_t k1 = 0; k1 < (KS / 4); ++k1) { - uint32_t k_start = k1 * 4; - uint8_t meta_mask; - if constexpr (cfg::b_split) { - uint32_t step_m = (m % cfg::tileM) / cfg::tcM; - meta_mask = (step_m & 1) ? 0b1010 : 0b0101; - } else { - uint32_t pos_in_tile = k_start % tile_k_elem; - meta_mask = (pos_in_tile < half_tile) ? 0b0101 : 0b1010; - } + for (uint32_t g = 0; g < num_groups; ++g) { + uint32_t k_start = g * 4; + uint8_t mask = masks[m * num_groups + g]; for (uint32_t k2 = 0; k2 < 4; ++k2) { - if (meta_mask & (1 << k2)) { + if (mask & (1 << k2)) { auto val = data_accessor_t::read(pruned_A, m * KS + k_start + k2); data_accessor_t::write(compressed, m * stride_comp + a_out, val); a_out++; @@ -718,6 +746,73 @@ static void compress_fixed_mask(itype_t *compressed, const itype_t *pruned_A, } } +// Pack per-group masks into VX_tcu_meta SRAM layout +// Output: h_meta vector indexed as [tile_row][k_tile][NT * meta_cols words] +static void pack_metadata(std::vector &h_meta, + const std::vector &masks, + uint32_t M, uint32_t K) { + constexpr uint32_t I_RATIO = cfg::rtl_i_ratio; + constexpr uint32_t TC_K = cfg::tcK; + constexpr uint32_t TC_M = cfg::tcM; + constexpr uint32_t meta_row_w = TC_K * 2 * I_RATIO; + constexpr uint32_t mcols = cfg::meta_cols; + constexpr uint32_t half_k_steps = cfg::k_steps / 2; + + uint32_t subbytes = (vt::ITYPE::bits < 8) ? (8 / vt::ITYPE::bits) : 0; + uint32_t tileK_elem = subbytes ? (cfg::tileK * subbytes) : cfg::tileK; + uint32_t KS = subbytes ? 
(K * subbytes) : K; + uint32_t num_groups_per_row = KS / 4; + uint32_t elts_per_sparse_step = tileK_elem / half_k_steps; + + uint32_t num_tile_rows = M / cfg::tileM; + uint32_t num_k_tiles = K / cfg::tileK; + uint32_t per_k_tile_words = NUM_THREADS * mcols; + + h_meta.assign(num_tile_rows * num_k_tiles * per_k_tile_words, 0); + + for (uint32_t tr = 0; tr < num_tile_rows; ++tr) { + for (uint32_t kt = 0; kt < num_k_tiles; ++kt) { + uint32_t section_base = (tr * num_k_tiles + kt) * per_k_tile_words; + + for (uint32_t sm = 0; sm < cfg::m_steps; ++sm) { + for (uint32_t sk = 0; sk < half_k_steps; ++sk) { + uint32_t sram_row = sm * half_k_steps + sk; + + for (uint32_t i = 0; i < TC_M; ++i) { + uint32_t physical_row = tr * cfg::tileM + sm * TC_M + i; + uint32_t k_elem_start = kt * tileK_elem + sk * elts_per_sparse_step; + uint32_t groups_in_step = elts_per_sparse_step / 4; + + for (uint32_t g = 0; g < groups_in_step; ++g) { + uint32_t global_group = (k_elem_start / 4) + g; + uint8_t mask = masks[physical_row * num_groups_per_row + global_group]; + + for (int p = 0; p < 4; ++p) { + if (mask & (1 << p)) { + // Map element position to meta_row bit position + uint32_t elt = g * 4 + p; + uint32_t k_reg = elt / (2 * I_RATIO); + uint32_t pos_in_k = elt % (2 * I_RATIO); + uint32_t meta_bit; + if (pos_in_k < I_RATIO) { + meta_bit = k_reg * I_RATIO + pos_in_k; + } else { + meta_bit = (TC_K + k_reg) * I_RATIO + (pos_in_k - I_RATIO); + } + uint32_t block_bit = i * meta_row_w + meta_bit; + uint32_t word_idx = block_bit / 32; + uint32_t bit_idx = block_bit % 32; + h_meta[section_base + sram_row * mcols + word_idx] |= (1u << bit_idx); + } + } + } + } + } + } + } + } +} + /////////////////////////////////////////////////////////////////////////////// const char *kernel_file = "kernel.vxbin"; @@ -858,10 +953,11 @@ int main(int argc, char *argv[]) { RT_CHECK(vx_mem_alloc(device, sizeC * sizeof(otype_t), VX_MEM_WRITE, &C_buffer)); RT_CHECK(vx_mem_address(C_buffer, &kernel_arg.C_addr)); - // 
allocate metadata buffer (padded to NT rows for all lanes) + // allocate metadata buffer per (tile_row, k_tile) constexpr uint32_t meta_cols = cfg::meta_cols; - constexpr uint32_t per_warp_depth = cfg::per_warp_depth; - constexpr uint32_t meta_buf_entries = NUM_THREADS * meta_cols; + uint32_t num_tile_rows = M / cfg::tileM; + uint32_t num_k_tiles = K / cfg::tileK; + uint32_t meta_buf_entries = num_tile_rows * num_k_tiles * NUM_THREADS * meta_cols; RT_CHECK(vx_mem_alloc(device, meta_buf_entries * sizeof(uint32_t), VX_MEM_READ, &meta_buffer)); RT_CHECK(vx_mem_address(meta_buffer, &kernel_arg.meta_addr)); @@ -876,9 +972,10 @@ int main(int argc, char *argv[]) { for (uint32_t i = 0; i < sizeA_full; ++i) { h_A_full[i] = generate_A_value(); } - prune_fixed_mask(h_A_full.data(), M, K); + std::vector masks; + prune_2to4(h_A_full.data(), masks, M, K); std::vector h_A(sizeA); - compress_fixed_mask(h_A.data(), h_A_full.data(), M, K); + compress_2to4(h_A.data(), h_A_full.data(), masks, M, K); std::vector h_B(sizeB); for (uint32_t i = 0; i < sizeB; ++i) { @@ -907,16 +1004,11 @@ int main(int argc, char *argv[]) { } } - // upload metadata buffer + // upload metadata buffer (real masks from pruning) { std::cout << "upload metadata buffer" << std::endl; - std::vector h_meta(meta_buf_entries, 0); - for (uint32_t row = 0; row < per_warp_depth; ++row) { - uint32_t pattern = (row & 1) ? 0xAAAAAAAA : 0x55555555; - for (uint32_t col = 0; col < meta_cols; ++col) { - h_meta[row * meta_cols + col] = pattern; - } - } + std::vector h_meta; + pack_metadata(h_meta, masks, M, K); RT_CHECK(vx_copy_to_dev(meta_buffer, h_meta.data(), 0, meta_buf_entries * sizeof(uint32_t))); } @@ -947,6 +1039,70 @@ int main(int argc, char *argv[]) { std::cout << "download destination buffer" << std::endl; RT_CHECK(vx_copy_from_dev(h_C.data(), C_buffer, 0, sizeC * sizeof(otype_t))); + // === DEBUG: dump masks, metadata, compressed A for row 0 === + { + uint32_t subbytes_d = (vt::ITYPE::bits < 8) ? 
(8 / vt::ITYPE::bits) : 0; + uint32_t KS_d = subbytes_d ? (K * subbytes_d) : K; + uint32_t num_groups_d = KS_d / 4; + std::cout << "=== DEBUG: ITYPE::bits=" << vt::ITYPE::bits + << " I_RATIO=" << cfg::rtl_i_ratio + << " TC_K=" << cfg::tcK << " TC_M=" << cfg::tcM + << " meta_cols=" << cfg::meta_cols + << " tileK=" << cfg::tileK + << " k_steps=" << cfg::k_steps + << " half_k_steps=" << cfg::k_steps/2 + << std::endl; + + // Print masks for row 0 + std::cout << "Masks row 0:"; + for (uint32_t g = 0; g < num_groups_d && g < 8; ++g) { + printf(" g%u=0x%x", g, masks[0 * num_groups_d + g]); + } + std::cout << std::endl; + + // Print compressed A for row 0 (first 8 elements) + uint32_t stride_comp_d = KS_d / 2; + std::cout << "Compressed A row 0 (hex):"; + for (uint32_t k = 0; k < stride_comp_d && k < 16; ++k) { + auto val = data_accessor_t::read(h_A.data(), 0 * stride_comp_d + k); + printf(" 0x%x", (unsigned)val); + } + std::cout << std::endl; + + // Recompute and print metadata words + std::vector h_meta_dbg; + pack_metadata(h_meta_dbg, masks, M, K); + constexpr uint32_t mcols_d = cfg::meta_cols; + uint32_t per_k_words_d = NUM_THREADS * mcols_d; + std::cout << "Metadata words (tile_row=0, k_tile=0):"; + for (uint32_t w = 0; w < per_k_words_d; ++w) { + printf(" [%u]=0x%08x", w, h_meta_dbg[w]); + } + std::cout << std::endl; + + // Decode metadata bits for sram_row 0 (sm=0, sk=0) + // Each sram_row has mcols_d words = mcols_d*32 bits + // TC_M rows, each META_ROW_WIDTH bits + constexpr uint32_t meta_row_w_d = cfg::tcK * 2 * cfg::rtl_i_ratio; + std::cout << " sram_row0 decoded (TC_M=" << cfg::tcM << " rows, " << meta_row_w_d << " bits each):" << std::endl; + uint32_t sram0_word = h_meta_dbg[0]; + for (uint32_t i = 0; i < cfg::tcM; ++i) { + uint32_t row_bits = (sram0_word >> (i * meta_row_w_d)) & ((1u << meta_row_w_d) - 1); + printf(" TC_M row %u: bits=0x%x (binary:", i, row_bits); + for (int b = meta_row_w_d-1; b >= 0; --b) printf("%d", (row_bits >> b) & 1); + printf(")\n"); 
+ } + + // Show what pruned A looks like for row 0 (full K) + std::cout << "Pruned A row 0 (full, first 16 hex):"; + for (uint32_t k = 0; k < KS_d && k < 16; ++k) { + auto val = data_accessor_t::read(h_A_full.data(), 0 * KS_d + k); + printf(" 0x%x", (unsigned)val); + } + std::cout << std::endl; + } + // === END DEBUG === + // verify result std::cout << "verify result" << std::endl; int errors = 0; @@ -954,6 +1110,86 @@ int main(int argc, char *argv[]) { std::vector h_ref(sizeC); matmul_cpu(h_ref.data(), h_A_full.data(), h_B.data(), M, N, K); + // Sparse reference: manually compute using compressed A + mask-selected B + // This mimics exactly what the hardware should do + uint32_t subbytes_v = (vt::ITYPE::bits < 8) ? (8 / vt::ITYPE::bits) : 0; + uint32_t KS_v = subbytes_v ? (K * subbytes_v) : K; + uint32_t stride_comp_v = KS_v / 2; + uint32_t num_groups_v = KS_v / 4; + std::vector h_sparse_ref(sizeC); + for (uint32_t m = 0; m < M; ++m) { + for (uint32_t n = 0; n < N; ++n) { + otype_t sum(0); + uint32_t comp_idx = 0; + for (uint32_t g = 0; g < num_groups_v; ++g) { + uint8_t mask = masks[m * num_groups_v + g]; + // Extract first set and last set positions (matching VX_tcu_sel) + int first_set = -1, last_set = -1; + for (int p = 0; p < 4; ++p) { + if (mask & (1 << p)) { + if (first_set < 0) first_set = p; + last_set = p; + } + } + uint32_t k_base = g * 4; + // compressed A stores in ascending order: first_set then last_set + auto a_first = data_accessor_t::read(h_A.data(), m * stride_comp_v + comp_idx); + auto a_last = data_accessor_t::read(h_A.data(), m * stride_comp_v + comp_idx + 1); + auto b_first = data_accessor_t::read(h_B.data(), (k_base + first_set) * N + n); + auto b_last = data_accessor_t::read(h_B.data(), (k_base + last_set) * N + n); + sum = muladd_t::eval(a_first, b_first, sum); + sum = muladd_t::eval(a_last, b_last, sum); + comp_idx += 2; + } + data_accessor_t::write(h_sparse_ref.data(), m * N + n, sum); + } + } + + // Compare sparse ref with dense ref 
(should match) + int sparse_ref_errors = 0; + for (uint32_t i = 0; i < sizeC; ++i) { + if (!Comparator::compare(h_sparse_ref[i], h_ref[i], i, sparse_ref_errors)) { + if (sparse_ref_errors <= 5) { + printf(" sparse_ref[%u]=%f vs cpu_ref[%u]=%f\n", i, + static_cast(h_sparse_ref[i]), i, + static_cast(h_ref[i])); + } + ++sparse_ref_errors; + } + } + if (sparse_ref_errors > 0) { + printf("WARNING: sparse_ref vs cpu_ref: %d / %u mismatches!\n", sparse_ref_errors, sizeC); + } else { + printf("sparse_ref vs cpu_ref: ALL MATCH\n"); + } + + // Compare GPU output with sparse ref + int gpu_vs_sparse = 0; + for (uint32_t i = 0; i < sizeC; ++i) { + if (!Comparator::compare(h_C[i], h_sparse_ref[i], i, gpu_vs_sparse)) { + if (gpu_vs_sparse <= 5) { + printf(" gpu[%u]=%f vs sparse_ref[%u]=%f\n", i, + static_cast(h_C[i]), i, + static_cast(h_sparse_ref[i])); + } + ++gpu_vs_sparse; + } + } + if (gpu_vs_sparse > 0) { + printf("GPU vs sparse_ref: %d / %u mismatches\n", gpu_vs_sparse, sizeC); + } else { + printf("GPU vs sparse_ref: ALL MATCH\n"); + } + + // Print first few entries for manual inspection + printf("First 8 entries: cpu_ref / sparse_ref / gpu\n"); + for (uint32_t i = 0; i < 8 && i < sizeC; ++i) { + printf(" [%u] %f / %f / %f\n", i, + static_cast(h_ref[i]), + static_cast(h_sparse_ref[i]), + static_cast(h_C[i])); + } + for (uint32_t i = 0; i < h_ref.size(); ++i) { if (!Comparator::compare(h_C[i], h_ref[i], i, errors)) { ++errors; From 348187e7d09afd82f28c79ebd92b14fc9c2f3de2 Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Mon, 16 Feb 2026 08:55:08 -0800 Subject: [PATCH 17/22] separate the tcu only time using csr hardware count --- tests/regression/sgemm_tcu/common.h | 1 + tests/regression/sgemm_tcu/kernel.cpp | 6 ++++++ tests/regression/sgemm_tcu/main.cpp | 17 +++++++++++++++++ .../regression/sgemm_tcu_struct_sparse/common.h | 1 + .../sgemm_tcu_struct_sparse/kernel.cpp | 6 ++++++ .../regression/sgemm_tcu_struct_sparse/main.cpp | 17 +++++++++++++++++ 6 files changed, 48 
insertions(+) diff --git a/tests/regression/sgemm_tcu/common.h b/tests/regression/sgemm_tcu/common.h index a762a4fb2e..a11916c616 100644 --- a/tests/regression/sgemm_tcu/common.h +++ b/tests/regression/sgemm_tcu/common.h @@ -22,6 +22,7 @@ typedef struct { uint64_t A_addr; uint64_t B_addr; uint64_t C_addr; + uint64_t tcu_cycles_addr; } kernel_arg_t; #endif diff --git a/tests/regression/sgemm_tcu/kernel.cpp b/tests/regression/sgemm_tcu/kernel.cpp index 1d722cbac7..1c3f0ef762 100644 --- a/tests/regression/sgemm_tcu/kernel.cpp +++ b/tests/regression/sgemm_tcu/kernel.cpp @@ -1,6 +1,7 @@ #include "common.h" #include #include +#include namespace vt = vortex::tensor; using ctx = vt::wmma_context; @@ -25,6 +26,7 @@ void kernel_body(kernel_arg_t *__UNIFORM__ arg) { // Initialize accumulator tile to zero ctx::fill_fragment(fragC, 0); + uint32_t cyc_start = csr_read(0xB00); for (int i = 0; i < K; i += ctx::tileK) { auto pTileA = pA + tile_row * K + i; @@ -44,6 +46,10 @@ void kernel_body(kernel_arg_t *__UNIFORM__ arg) { // Matrix multiply-accumulate: c += a * b ctx::mma_sync(fragC, fragA, fragB, fragC); } + uint32_t cyc_end = csr_read(0xB00); + auto pCycles = reinterpret_cast(arg->tcu_cycles_addr); + uint32_t block_id = blockIdx.y * arg->grid_dim[0] + blockIdx.x; + pCycles[block_id] = cyc_end - cyc_start; // Store the computed C tile auto pTileC = pC + tile_row * N + tile_col; diff --git a/tests/regression/sgemm_tcu/main.cpp b/tests/regression/sgemm_tcu/main.cpp index 21fec90f15..f7235c8b1b 100644 --- a/tests/regression/sgemm_tcu/main.cpp +++ b/tests/regression/sgemm_tcu/main.cpp @@ -669,6 +669,7 @@ vx_device_h device = nullptr; vx_buffer_h A_buffer = nullptr; vx_buffer_h B_buffer = nullptr; vx_buffer_h C_buffer = nullptr; +vx_buffer_h cycles_buffer = nullptr; vx_buffer_h krnl_buffer = nullptr; vx_buffer_h args_buffer = nullptr; kernel_arg_t kernel_arg = {}; @@ -709,6 +710,7 @@ void cleanup() { vx_mem_free(A_buffer); vx_mem_free(B_buffer); vx_mem_free(C_buffer); + 
vx_mem_free(cycles_buffer); vx_mem_free(krnl_buffer); vx_mem_free(args_buffer); vx_dev_close(device); @@ -795,6 +797,10 @@ int main(int argc, char *argv[]) { RT_CHECK(vx_mem_alloc(device, sizeC * sizeof(otype_t), VX_MEM_WRITE, &C_buffer)); RT_CHECK(vx_mem_address(C_buffer, &kernel_arg.C_addr)); + uint32_t num_blocks = kernel_arg.grid_dim[0] * kernel_arg.grid_dim[1]; + RT_CHECK(vx_mem_alloc(device, num_blocks * sizeof(uint32_t), VX_MEM_WRITE, &cycles_buffer)); + RT_CHECK(vx_mem_address(cycles_buffer, &kernel_arg.tcu_cycles_addr)); + std::cout << "A_addr=0x" << std::hex << kernel_arg.A_addr << std::endl; std::cout << "B_addr=0x" << std::hex << kernel_arg.B_addr << std::endl; std::cout << "C_addr=0x" << std::hex << kernel_arg.C_addr << std::endl; @@ -858,6 +864,17 @@ int main(int argc, char *argv[]) { std::cout << "download destination buffer" << std::endl; RT_CHECK(vx_copy_from_dev(h_C.data(), C_buffer, 0, sizeC * sizeof(otype_t))); + // download TCU K-loop cycle counts + { + std::vector h_cycles(num_blocks); + RT_CHECK(vx_copy_from_dev(h_cycles.data(), cycles_buffer, 0, num_blocks * sizeof(uint32_t))); + uint32_t max_cyc = 0; + for (uint32_t i = 0; i < num_blocks; ++i) { + if (h_cycles[i] > max_cyc) max_cyc = h_cycles[i]; + } + printf("TCU_CYCLES: max=%u (across %u blocks)\n", max_cyc, num_blocks); + } + // verify result std::cout << "verify result" << std::endl; int errors = 0; diff --git a/tests/regression/sgemm_tcu_struct_sparse/common.h b/tests/regression/sgemm_tcu_struct_sparse/common.h index 478fe9b733..eaaf6b5fbf 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/common.h +++ b/tests/regression/sgemm_tcu_struct_sparse/common.h @@ -23,6 +23,7 @@ typedef struct { uint64_t B_addr; uint64_t C_addr; uint64_t meta_addr; + uint64_t tcu_cycles_addr; } kernel_arg_t; #endif diff --git a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp index 3fdec5061e..e79c3c7150 100644 --- 
a/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/kernel.cpp @@ -1,6 +1,7 @@ #include "common.h" #include #include +#include namespace vt = vortex::tensor; using ctx = vt::wmma_context; @@ -32,6 +33,7 @@ void kernel_body(kernel_arg_t *__UNIFORM__ arg) { uint32_t tile_row_idx = blockIdx.y; uint32_t stride_A = K / 2; + uint32_t cyc_start = csr_read(0xB00); for (int i = 0; i < (int)K; i += (int)ctx::tileK) { // Load metadata for this K-tile uint32_t k_tile = i / ctx::tileK; @@ -51,6 +53,10 @@ void kernel_body(kernel_arg_t *__UNIFORM__ arg) { ctx::mma_sync(fragC, fragA, fragB, fragC); } + uint32_t cyc_end = csr_read(0xB00); + auto pCycles = reinterpret_cast(arg->tcu_cycles_addr); + uint32_t block_id = blockIdx.y * arg->grid_dim[0] + blockIdx.x; + pCycles[block_id] = cyc_end - cyc_start; auto pTileC = pC + tile_row * N + tile_col; ctx::store_matrix_sync(pTileC, fragC, N); diff --git a/tests/regression/sgemm_tcu_struct_sparse/main.cpp b/tests/regression/sgemm_tcu_struct_sparse/main.cpp index 398a0588a4..1b39316f9b 100644 --- a/tests/regression/sgemm_tcu_struct_sparse/main.cpp +++ b/tests/regression/sgemm_tcu_struct_sparse/main.cpp @@ -826,6 +826,7 @@ vx_buffer_h A_buffer = nullptr; vx_buffer_h B_buffer = nullptr; vx_buffer_h C_buffer = nullptr; vx_buffer_h meta_buffer = nullptr; +vx_buffer_h cycles_buffer = nullptr; vx_buffer_h krnl_buffer = nullptr; vx_buffer_h args_buffer = nullptr; kernel_arg_t kernel_arg = {}; @@ -867,6 +868,7 @@ void cleanup() { vx_mem_free(B_buffer); vx_mem_free(C_buffer); vx_mem_free(meta_buffer); + vx_mem_free(cycles_buffer); vx_mem_free(krnl_buffer); vx_mem_free(args_buffer); vx_dev_close(device); @@ -961,6 +963,10 @@ int main(int argc, char *argv[]) { RT_CHECK(vx_mem_alloc(device, meta_buf_entries * sizeof(uint32_t), VX_MEM_READ, &meta_buffer)); RT_CHECK(vx_mem_address(meta_buffer, &kernel_arg.meta_addr)); + uint32_t num_blocks = kernel_arg.grid_dim[0] * kernel_arg.grid_dim[1]; + 
RT_CHECK(vx_mem_alloc(device, num_blocks * sizeof(uint32_t), VX_MEM_WRITE, &cycles_buffer)); + RT_CHECK(vx_mem_address(cycles_buffer, &kernel_arg.tcu_cycles_addr)); + std::cout << "A_addr=0x" << std::hex << kernel_arg.A_addr << std::endl; std::cout << "B_addr=0x" << std::hex << kernel_arg.B_addr << std::endl; std::cout << "C_addr=0x" << std::hex << kernel_arg.C_addr << std::endl; @@ -1039,6 +1045,17 @@ int main(int argc, char *argv[]) { std::cout << "download destination buffer" << std::endl; RT_CHECK(vx_copy_from_dev(h_C.data(), C_buffer, 0, sizeC * sizeof(otype_t))); + // download TCU K-loop cycle counts + { + std::vector h_cycles(num_blocks); + RT_CHECK(vx_copy_from_dev(h_cycles.data(), cycles_buffer, 0, num_blocks * sizeof(uint32_t))); + uint32_t max_cyc = 0; + for (uint32_t i = 0; i < num_blocks; ++i) { + if (h_cycles[i] > max_cyc) max_cyc = h_cycles[i]; + } + printf("TCU_CYCLES: max=%u (across %u blocks)\n", max_cyc, num_blocks); + } + // === DEBUG: dump masks, metadata, compressed A for row 0 === { uint32_t subbytes_d = (vt::ITYPE::bits < 8) ? 
(8 / vt::ITYPE::bits) : 0; From 0613bb23f23471bbee95aa7837981e2d13866f11 Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Tue, 17 Feb 2026 19:46:50 -0800 Subject: [PATCH 18/22] past NT=16 clean up --- hw/rtl/tcu/VX_tcu_core.sv | 38 +++------------------------- hw/rtl/tcu/VX_tcu_meta.sv | 11 +++----- hw/rtl/tcu/VX_tcu_pkg.sv | 3 +-- hw/rtl/tcu/VX_tcu_uops.sv | 52 ++++++++++++++------------------------ kernel/include/vx_tensor.h | 4 +-- sim/common/tensor_cfg.h | 5 ++-- 6 files changed, 31 insertions(+), 82 deletions(-) diff --git a/hw/rtl/tcu/VX_tcu_core.sv b/hw/rtl/tcu/VX_tcu_core.sv index 7dc6db7457..61e3567386 100644 --- a/hw/rtl/tcu/VX_tcu_core.sv +++ b/hw/rtl/tcu/VX_tcu_core.sv @@ -98,21 +98,6 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( wire result_fire = result_if.valid && result_if.ready; wire fedp_enable, fedp_done; - // B_SPLIT: Phase 1 (step_k[0]=0) latches rs2, Phase 2 (step_k[0]=1) computes - wire b_split_phase1 = (TCU_B_SPLIT != 0) & is_sparse & ~step_k[0]; - - // B_SPLIT: per-warp latch for rs2_data (prevents cross-warp corruption) - if (TCU_B_SPLIT) begin : g_bsplit - reg [`NUM_WARPS-1:0][`NUM_TCU_LANES-1:0][`XLEN-1:0] rs2_data_latch; - wire [`LOG2UP(`NUM_WARPS)-1:0] bsplit_wid = execute_if.data.header.wid; - always @(posedge clk) begin - if (reset) - rs2_data_latch <= '0; - else if (execute_fire & b_split_phase1) - rs2_data_latch[bsplit_wid] <= execute_if.data.rs2_data; - end - end - // FEDP delay handling reg [PIPE_LATENCY-1:0] fedp_delay_pipe; always @(posedge clk) begin @@ -153,9 +138,7 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( wire [OFF_W-1:0] a_off = (OFF_W'(step_m) & OFF_W'(TCU_A_SUB_BLOCKS-1)) << LG_A_BS; wire [OFF_W-1:0] b_off = is_sparse - ? (TCU_B_SPLIT - ? (OFF_W'(step_n) & OFF_W'(TCU_B_SUB_BLOCKS-1)) << LG_B_BS - : (OFF_W'(step_n) & OFF_W'(TCU_B_SUB_BLOCKS_SP-1)) << LG_B_BS_SP) + ? 
(OFF_W'(step_n) & OFF_W'(TCU_B_SUB_BLOCKS_SP-1)) << LG_B_BS_SP : (OFF_W'(step_n) & OFF_W'(TCU_B_SUB_BLOCKS-1)) << LG_B_BS; wire [TCU_TC_M-1:0][TCU_TC_N-1:0][31:0] d_val; @@ -193,20 +176,8 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( for (genvar k_idx = 0; k_idx < TCU_TC_K; ++k_idx) begin : g_slice_assign assign a_row[k_idx] = 32'(execute_if.data.rs1_data[a_off + i * TCU_TC_K + k_idx]); assign b_col_dense[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K + k_idx]); - if (TCU_B_SPLIT) begin : g_bsplit_col - // B_SPLIT: pair adjacent lanes within same source (interleaved) - // First half of k uses Phase 1 latch, second half uses Phase 2 rs2 - if (k_idx < (TCU_TC_K / 2)) begin : g_phase1_lane - assign b_col_1[k_idx] = 32'(g_bsplit.rs2_data_latch[g_bsplit.bsplit_wid][b_off + j * TCU_TC_K + k_idx * 2]); - assign b_col_2[k_idx] = 32'(g_bsplit.rs2_data_latch[g_bsplit.bsplit_wid][b_off + j * TCU_TC_K + k_idx * 2 + 1]); - end else begin : g_phase2_lane - assign b_col_1[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K + (k_idx - TCU_TC_K/2) * 2]); - assign b_col_2[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K + (k_idx - TCU_TC_K/2) * 2 + 1]); - end - end else begin : g_std_col - assign b_col_1[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2]); - assign b_col_2[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2 + 1]); - end + assign b_col_1[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2]); + assign b_col_2[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2 + 1]); end wire [31:0] c_val = 32'(execute_if.data.rs3_data[i * TCU_TC_N + j]); @@ -226,8 +197,7 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( ); // Select dense or sparse B column - // B_SPLIT Phase 1: zero b_col so FEDP computes 0+c=c (passthrough) - assign b_col = b_split_phase1 ? '0 : (is_sparse ? b_col_sparse : b_col_dense); + assign b_col = is_sparse ? 
b_col_sparse : b_col_dense; wire [3:0] fmt_s_r, fmt_d_r; wire [TCU_TC_K-1:0][31:0] a_row_r, b_col_r; diff --git a/hw/rtl/tcu/VX_tcu_meta.sv b/hw/rtl/tcu/VX_tcu_meta.sv index ccd6d534a5..1f74dd9a24 100644 --- a/hw/rtl/tcu/VX_tcu_meta.sv +++ b/hw/rtl/tcu/VX_tcu_meta.sv @@ -43,19 +43,14 @@ module VX_tcu_meta import VX_gpu_pkg::*, VX_tcu_pkg::*; #( localparam ADDRW = `CLOG2(TOTAL_DEPTH); localparam ADDRW_PW = `CLOG2(PER_WARP_DEPTH); localparam M_STEP_BITS = `CLOG2(TCU_M_STEPS); - localparam K_STEP_BITS = (HALF_K_STEPS > 1) ? `CLOG2(HALF_K_STEPS) : 0; + localparam K_STEP_BITS = `CLOG2(HALF_K_STEPS); localparam NUM_COLS = META_BLOCK_WIDTH / 32; // Metadata register array (per-warp partitioned) reg [META_BLOCK_WIDTH-1:0] meta_mem [0:TOTAL_DEPTH-1]; - // Read address: {wid, step_m, step_k} - wire [ADDRW_PW-1:0] per_warp_raddr; - if (K_STEP_BITS > 0) begin : g_addr_mk - assign per_warp_raddr = {step_m[M_STEP_BITS-1:0], step_k[K_STEP_BITS-1:0]}; - end else begin : g_addr_m - assign per_warp_raddr = step_m[M_STEP_BITS-1:0]; - end + // Read address: {wid, step_m, step_k} + wire [ADDRW_PW-1:0] per_warp_raddr = {step_m[M_STEP_BITS-1:0], step_k[K_STEP_BITS-1:0]}; wire [ADDRW-1:0] read_addr = {raddr_wid, per_warp_raddr}; // Combinational read diff --git a/hw/rtl/tcu/VX_tcu_pkg.sv b/hw/rtl/tcu/VX_tcu_pkg.sv index e11e6f2777..89dfb1025f 100644 --- a/hw/rtl/tcu/VX_tcu_pkg.sv +++ b/hw/rtl/tcu/VX_tcu_pkg.sv @@ -79,8 +79,7 @@ package VX_tcu_pkg; // B micro-tiling (sparse 2:4) localparam TCU_B_BLOCK_SIZE_SP = (TCU_TC_K * TCU_TC_N) * 2; - localparam TCU_B_SPLIT = (TCU_B_BLOCK_SIZE_SP > TCU_BLOCK_CAP); - localparam TCU_B_SUB_BLOCKS_SP = TCU_B_SPLIT ? 
1 : (TCU_BLOCK_CAP / TCU_B_BLOCK_SIZE_SP); + localparam TCU_B_SUB_BLOCKS_SP = TCU_BLOCK_CAP / TCU_B_BLOCK_SIZE_SP; // Register counts //localparam TCU_NRA = (TCU_TILE_M * TCU_TILE_K) / TCU_NT; diff --git a/hw/rtl/tcu/VX_tcu_uops.sv b/hw/rtl/tcu/VX_tcu_uops.sv index c9f284c3c4..40f0353e9f 100644 --- a/hw/rtl/tcu/VX_tcu_uops.sv +++ b/hw/rtl/tcu/VX_tcu_uops.sv @@ -35,7 +35,7 @@ module VX_tcu_uops import localparam LG_A_SB = $clog2(TCU_A_SUB_BLOCKS); localparam LG_B_SB = $clog2(TCU_B_SUB_BLOCKS); - localparam LG_B_SB_SP = TCU_B_SPLIT ? 0 : $clog2(TCU_B_SUB_BLOCKS_SP); + localparam LG_B_SB_SP = $clog2(TCU_B_SUB_BLOCKS_SP); wire is_sparse_in = (ibuf_in.op_type == INST_TCU_WMMA_SP); reg is_sparse; @@ -47,43 +47,31 @@ module VX_tcu_uops import logic [`UP(LG_M)-1:0] m_index; logic [`UP(LG_K)-1:0] k_index; - if (TCU_B_SPLIT) begin : g_idx_bsplit - // B_SPLIT: when sparse, k iterates fastest so Phase1/Phase2 are consecutive - // when dense, use original order (n fastest) - assign k_index = is_sparse ? counter[0 +: LG_K] : counter[LG_N + LG_M +: LG_K]; - assign n_index = is_sparse ? counter[LG_K +: LG_N] : counter[0 +: LG_N]; - assign m_index = is_sparse ? 
counter[LG_K + LG_N +: LG_M] : counter[LG_N +: LG_M]; - end else begin : g_idx_normal - if (LG_N != 0) begin : g_n_idx - assign n_index = counter[0 +: LG_N]; - end else begin : g_n_idx0 - assign n_index = 0; - end + if (LG_N != 0) begin : g_n_idx + assign n_index = counter[0 +: LG_N]; + end else begin : g_n_idx0 + assign n_index = 0; + end - if (LG_M != 0) begin : g_m_idx - assign m_index = counter[LG_N +: LG_M]; - end else begin : g_m_idx0 - assign m_index = 0; - end + if (LG_M != 0) begin : g_m_idx + assign m_index = counter[LG_N +: LG_M]; + end else begin : g_m_idx0 + assign m_index = 0; + end - if (LG_K != 0) begin : g_k_idx - assign k_index = counter[LG_N + LG_M +: LG_K]; - end else begin : g_k_idx0 - assign k_index = 0; - end + if (LG_K != 0) begin : g_k_idx + assign k_index = counter[LG_N + LG_M +: LG_K]; + end else begin : g_k_idx0 + assign k_index = 0; end // Register offsets — dense vs sparse formulas wire [CTR_W-1:0] rs1_offset = is_sparse - ? (TCU_B_SPLIT - ? (CTR_W'(m_index) >> LG_A_SB) - : ((CTR_W'(m_index) >> LG_A_SB) << (LG_K/2)) | CTR_W'(k_index)) + ? ((CTR_W'(m_index) >> LG_A_SB) << (LG_K/2)) | CTR_W'(k_index) : ((CTR_W'(m_index) >> LG_A_SB) << LG_K) | CTR_W'(k_index); wire [CTR_W-1:0] rs2_offset = is_sparse - ? (TCU_B_SPLIT - ? ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB - : ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB_SP) + ? ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB_SP : ((CTR_W'(k_index) << LG_N) | CTR_W'(n_index)) >> LG_B_SB; wire [CTR_W-1:0] rs3_offset = (CTR_W'(m_index) << LG_N) | CTR_W'(n_index); @@ -138,12 +126,10 @@ module VX_tcu_uops import counter <= 0; busy <= 1; is_sparse <= is_sparse_in; - done <= (is_sparse_in && !TCU_B_SPLIT) - ? (TCU_UOPS/2 == 1) - : (TCU_UOPS == 1); + done <= is_sparse_in ? (TCU_UOPS/2 == 1) : (TCU_UOPS == 1); end else if (busy && next) begin counter <= counter + ((TCU_UOPS > 1) ? 1 : 0); - done <= (is_sparse && !TCU_B_SPLIT) + done <= is_sparse ? 
(counter == CTR_W'((TCU_UOPS/2)-2)) : (counter == CTR_W'(TCU_UOPS-2)); busy <= ~done; diff --git a/kernel/include/vx_tensor.h b/kernel/include/vx_tensor.h index 69cb46f55f..6201e2f2df 100644 --- a/kernel/include/vx_tensor.h +++ b/kernel/include/vx_tensor.h @@ -247,8 +247,8 @@ struct wmma_context { }); } } else if constexpr (Frag::Use == matrix_b) { - if constexpr (sparse && !cfg::b_split) { - // Sparse B load (non-B_SPLIT): uses 2x tcK for B block + if constexpr (sparse) { + // Sparse B load: uses 2x tcK for B block constexpr uint32_t b_tcK = cfg::tcK * 2; uint32_t block_idx = (cfg::b_block_size_sp == NT) ? 0 : (lane / cfg::b_block_size_sp); uint32_t lane_in_blk = (cfg::b_block_size_sp == NT) ? lane : (lane % cfg::b_block_size_sp); diff --git a/sim/common/tensor_cfg.h b/sim/common/tensor_cfg.h index ddf3b4e4d2..f633888c7b 100644 --- a/sim/common/tensor_cfg.h +++ b/sim/common/tensor_cfg.h @@ -196,9 +196,8 @@ struct wmma_config_t { static constexpr uint32_t b_sub_steps = n_steps / b_sub_blocks; // number of B sub-steps per register static constexpr uint32_t b_block_size_sp = (tcK * tcN) * 2; // sparse 2:4 - static constexpr bool b_split = (b_block_size_sp > NT); - static constexpr uint32_t b_sub_blocks_sp = b_split ? 1 : (block_cap / b_block_size_sp); - static constexpr uint32_t b_sub_steps_sp = b_split ? 
0 : (n_steps / b_sub_blocks_sp); + static constexpr uint32_t b_sub_blocks_sp = block_cap / b_block_size_sp; + static constexpr uint32_t b_sub_steps_sp = n_steps / b_sub_blocks_sp; static constexpr uint32_t NRA = (xtileM * xtileK) / NT; // Number of A registers static constexpr uint32_t NRB = (xtileN * xtileK) / NT; // Number of B registers From 741750a0ace84975f8c1962516313dd53f40ed16 Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Thu, 19 Feb 2026 22:33:15 -0800 Subject: [PATCH 19/22] comment from professor --- kernel/include/vx_tensor.h | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/kernel/include/vx_tensor.h b/kernel/include/vx_tensor.h index 6201e2f2df..6f816a89f5 100644 --- a/kernel/include/vx_tensor.h +++ b/kernel/include/vx_tensor.h @@ -377,10 +377,16 @@ struct wmma_context { template static __attribute__((always_inline)) void meta_store(float data) { - __asm__ volatile(".insn r 0x0b, 2, 2, x%[col], %[data], x0" + __asm__ volatile(".insn r 0x0b, 2, 2, x%[col], %[data], x0" // RISCV_CUSTOM0 instead of 0b :: [col]"i"(COL), [data]"f"(data)); } +// // Set thread mask // "memory" comment stop compiler reordering. 
+// inline void vx_tmc(int thread_mask) { +// __asm__ volatile (".insn r %0, 0, 0, x0, %1, x0" :: "i"(RISCV_CUSTOM0), "r"(thread_mask) : "memory"); +// } + + static __attribute__((always_inline)) void load_metadata_sync(const void* meta_ptr) { constexpr uint32_t rtl_i_ratio = 32 / It::bits; constexpr uint32_t num_cols = (NT * 2 * rtl_i_ratio) / 32; From 9a2c32d896981151dd5d2534655b285e6a8a4ecc Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Thu, 19 Feb 2026 23:49:31 -0800 Subject: [PATCH 20/22] fix verilator lint warning for vld_mask after upstream merge Co-Authored-By: Claude Opus 4.6 --- hw/rtl/tcu/VX_tcu_core.sv | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hw/rtl/tcu/VX_tcu_core.sv b/hw/rtl/tcu/VX_tcu_core.sv index fd54e0974f..8400674208 100644 --- a/hw/rtl/tcu/VX_tcu_core.sv +++ b/hw/rtl/tcu/VX_tcu_core.sv @@ -180,8 +180,9 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #( assign b_col_2[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2 + 1]); end wire [31:0] c_val = 32'(execute_if.data.rs3_data[i * TCU_TC_N + j]); - + /* verilator lint_off UNUSEDSIGNAL */ wire [TCU_MAX_INPUTS-1:0] vld_mask = '1; // TODO: should connect to input source + /* verilator lint_on UNUSEDSIGNAL */ wire [META_ROW_WIDTH-1:0] vld_meta_row = vld_meta_block[META_ROW_WIDTH*i +: META_ROW_WIDTH]; VX_tcu_sel #( From 8bfc6b618ea7c52871b8192ab947bc1b00bf2b2f Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Fri, 20 Feb 2026 00:52:32 -0800 Subject: [PATCH 21/22] fix SimX get_barrier_phase for global barriers The function was missing global barrier flag handling (bit 31 of bar_id). Other barrier functions in emulator.cpp already route global barriers to socket->get_barrier_phase(), but this getter did not, causing SimX test failures after the upstream barrier instruction merge. 
Co-Authored-By: Claude Opus 4.6 --- sim/simx/emulator.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sim/simx/emulator.cpp b/sim/simx/emulator.cpp index bbb3e6c654..e055cc39eb 100644 --- a/sim/simx/emulator.cpp +++ b/sim/simx/emulator.cpp @@ -276,6 +276,11 @@ bool Emulator::wspawn(uint32_t num_warps, Word nextPC) { } uint32_t Emulator::get_barrier_phase(uint32_t bar_id) const { + bool is_global = (bar_id >> 31); + bar_id &= 0x7fffffff; + if (is_global) { + return core_->socket()->get_barrier_phase(bar_id); + } return barriers_.at(bar_id).phase; } From 9cc55a47d0bfb979949cdbee7240c44c273d8999 Mon Sep 17 00:00:00 2001 From: yanggon-kim Date: Mon, 23 Feb 2026 21:51:56 -0800 Subject: [PATCH 22/22] fix VX_tcu_meta address: use generate-if bit-concatenation instead of arithmetic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the multiplier-based address calculation (step_m * HALF_K_STEPS + step_k) with a generate-if selecting pure bit-concatenation at elaboration time. This fixes the Verilator SELRANGE error when HALF_K_STEPS=1 (e.g. NT=4) without introducing combinational logic — all paths are wire routing only. 
Co-Authored-By: Claude Opus 4.6 --- hw/rtl/tcu/VX_tcu_meta.sv | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/hw/rtl/tcu/VX_tcu_meta.sv b/hw/rtl/tcu/VX_tcu_meta.sv index 1f74dd9a24..21033636fd 100644 --- a/hw/rtl/tcu/VX_tcu_meta.sv +++ b/hw/rtl/tcu/VX_tcu_meta.sv @@ -42,15 +42,28 @@ module VX_tcu_meta import VX_gpu_pkg::*, VX_tcu_pkg::*; #( localparam TOTAL_DEPTH = `NUM_WARPS * PER_WARP_DEPTH; localparam ADDRW = `CLOG2(TOTAL_DEPTH); localparam ADDRW_PW = `CLOG2(PER_WARP_DEPTH); - localparam M_STEP_BITS = `CLOG2(TCU_M_STEPS); - localparam K_STEP_BITS = `CLOG2(HALF_K_STEPS); localparam NUM_COLS = META_BLOCK_WIDTH / 32; // Metadata register array (per-warp partitioned) reg [META_BLOCK_WIDTH-1:0] meta_mem [0:TOTAL_DEPTH-1]; - // Read address: {wid, step_m, step_k} - wire [ADDRW_PW-1:0] per_warp_raddr = {step_m[M_STEP_BITS-1:0], step_k[K_STEP_BITS-1:0]}; + // Read address: bit-concatenation of step_m and step_k (pure wire routing, zero delay) + // Use generate-if to avoid zero-width bit-selects when a dimension has only 1 step + localparam M_STEP_BITS = `CLOG2(TCU_M_STEPS); + localparam K_STEP_BITS = `CLOG2(HALF_K_STEPS); + + wire [ADDRW_PW-1:0] per_warp_raddr; + generate + if (K_STEP_BITS > 0 && M_STEP_BITS > 0) begin : g_addr_mk + assign per_warp_raddr = {step_m[M_STEP_BITS-1:0], step_k[K_STEP_BITS-1:0]}; + end else if (K_STEP_BITS > 0) begin : g_addr_k + assign per_warp_raddr = step_k[K_STEP_BITS-1:0]; + end else if (M_STEP_BITS > 0) begin : g_addr_m + assign per_warp_raddr = step_m[M_STEP_BITS-1:0]; + end else begin : g_addr_zero + assign per_warp_raddr = '0; + end + endgenerate wire [ADDRW-1:0] read_addr = {raddr_wid, per_warp_raddr}; // Combinational read