From ad748dadd66f6e0e9620d95dfa5b172ed67f28b0 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 9 Dec 2025 15:58:03 -0600 Subject: [PATCH 01/13] GEMM reference HIP implementation --- tests/cpp/operator/test_cublaslt_gemm.cu | 309 ++++++++++++++++++----- 1 file changed, 245 insertions(+), 64 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 071470bdf..e1e0b9316 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -51,11 +51,224 @@ using TShape = std::vector; } // namespace -float ref_gelu(float x){ +__device__ __host__ __forceinline__ float ref_gelu(float x){ float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); return x * cdf; } +template +__global__ void compute_ref_kernel( + const A_Type* __restrict__ a_data, + const B_Type* __restrict__ b_data, + float a_scale_inv_scalar, // used when mxfp8 == false + float b_scale_inv_scalar, + const fp8e8m0* __restrict__ a_scale_inv_mxfp8, // used when mxfp8 == true + const fp8e8m0* __restrict__ b_scale_inv_mxfp8, + const Bias_Type* __restrict__ bias_data, + float d_scale, + size_t m, size_t k, size_t n, + D_Type* __restrict__ d_data, + float* __restrict__ d_amax, + Gelu_Type* __restrict__ gelu_data, + bool transa, + bool transb, + bool is_fp8_output) +{ + const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; + const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; + + if (ii >= m || jj >= n) + return; + + float val = 0.0f; + + for (size_t kk = 0; kk < k; ++kk) { + const size_t a_idx = transa ? (ii * k + kk) : (kk * m + ii); + const size_t b_idx = transb ? (kk * n + jj) : (jj * k + kk); + + float a_scale_inv_val = a_scale_inv_scalar; + float b_scale_inv_val = b_scale_inv_scalar; + + if (a_scale_inv_mxfp8) { + const size_t a_scale_idx = + transa ? (a_idx / 32) : ((kk / 32) * m + ii); + const size_t b_scale_idx = + transb ? 
((kk / 32) * n + jj) : (b_idx / 32); + + const float a_byte = static_cast(a_scale_inv_mxfp8[a_scale_idx]); + const float b_byte = static_cast(b_scale_inv_mxfp8[b_scale_idx]); + + a_scale_inv_val = exp2f(a_byte - 127.0f); + b_scale_inv_val = exp2f(b_byte - 127.0f); + } + + const float a_val = a_data[a_idx]; + const float b_val = b_data[b_idx]; + + val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; + } + + if (bias_data) { + val += (float)bias_data[ii]; + } + + if (gelu_data) { + gelu_data[ii + jj * m] = val; + val = ref_gelu(val); + } + + const float scaled = val * d_scale; + d_data[ii + jj * m] = scaled; + + if (is_fp8_output && d_amax) { + atomicMax(d_amax, fabsf(val)); + } +} + +// Common implementation used by both tensor-wise and MXFP8 frontends +template +static void compute_ref_impl( + const A_Type* a_data, + const B_Type* b_data, + float a_scale_inv_scalar, // used when mxfp8 == false + float b_scale_inv_scalar, + const fp8e8m0* a_scale_inv_mxfp8, // used when mxfp8 == true + const fp8e8m0* b_scale_inv_mxfp8, + const Bias_Type* bias_data, + float d_scale, + size_t m, size_t k, size_t n, + D_Type* d_data, + float* d_amax_host, + Gelu_Type* gelu_data, + bool transa, + bool transb) +{ + using transformer_engine::DType; + using ::TypeInfo; + using ::isFp8Type; + + const bool use_mxfp8 = (a_scale_inv_mxfp8 != nullptr); + + const DType dtype = TypeInfo::dtype; + const bool is_fp8_output = isFp8Type(dtype); + + const size_t lenA = m * k; + const size_t lenB = k * n; + const size_t lenD = m * n; + const size_t lenBias = m; + const size_t lenGelu = m * n; + + const size_t lenA_scale = use_mxfp8 ? (lenA + 31) / 32 : 0; + const size_t lenB_scale = use_mxfp8 ? (lenB + 31) / 32 : 0; + + A_Type* dA = nullptr; + B_Type* dB = nullptr; + Bias_Type* dBias = nullptr; + D_Type* dD = nullptr; + Gelu_Type* dGelu = nullptr; + float* dAmax = nullptr; + fp8e8m0* dA_scale = nullptr; + fp8e8m0* dB_scale = nullptr; + + // Allocations and H2D transfers + NVTE_CHECK_CUDA(cudaMalloc(&dA, lenA * sizeof(A_Type))); + NVTE_CHECK_CUDA(cudaMalloc(&dB, lenB * sizeof(B_Type))); + NVTE_CHECK_CUDA(cudaMalloc(&dD, lenD * sizeof(D_Type))); + + NVTE_CHECK_CUDA(cudaMemcpy( + dA, a_data, lenA * sizeof(A_Type), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy( + dB, b_data, lenB * sizeof(B_Type), cudaMemcpyHostToDevice)); + + if (bias_data) { + NVTE_CHECK_CUDA(cudaMalloc(&dBias, lenBias * sizeof(Bias_Type))); + NVTE_CHECK_CUDA(cudaMemcpy( + dBias, bias_data, lenBias * sizeof(Bias_Type), + cudaMemcpyHostToDevice)); + } + + if (gelu_data) { + NVTE_CHECK_CUDA(cudaMalloc(&dGelu, lenGelu * sizeof(Gelu_Type))); + NVTE_CHECK_CUDA(cudaMemset(dGelu, 0, lenGelu * sizeof(Gelu_Type))); + } + + if (use_mxfp8) { + NVTE_CHECK_CUDA(cudaMalloc(&dA_scale, lenA_scale * sizeof(fp8e8m0))); + NVTE_CHECK_CUDA(cudaMalloc(&dB_scale, lenB_scale * sizeof(fp8e8m0))); + NVTE_CHECK_CUDA(cudaMemcpy( + dA_scale, a_scale_inv_mxfp8, lenA_scale * sizeof(fp8e8m0), + cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy( + dB_scale, b_scale_inv_mxfp8, lenB_scale * sizeof(fp8e8m0), + cudaMemcpyHostToDevice)); + } + + if (is_fp8_output && d_amax_host) { + NVTE_CHECK_CUDA(cudaMalloc(&dAmax, sizeof(float))); + NVTE_CHECK_CUDA(cudaMemset(dAmax, 0, sizeof(float))); + } + + // Kernel launch + dim3 block(16, 16); + dim3 grid((n + block.x - 1) / block.x, (m + block.y - 1) / block.y); + + compute_ref_kernel + <<>>( + dA, + dB, + a_scale_inv_scalar, + b_scale_inv_scalar, + dA_scale, + dB_scale, + dBias, + d_scale, + m, k, n, + dD, + dAmax, + dGelu, + transa, + 
transb, + is_fp8_output); + + NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + // D2H copies + NVTE_CHECK_CUDA(cudaMemcpy( + d_data, dD, lenD * sizeof(D_Type), cudaMemcpyDeviceToHost)); + + if (gelu_data) { + NVTE_CHECK_CUDA(cudaMemcpy( + gelu_data, dGelu, lenGelu * sizeof(Gelu_Type), + cudaMemcpyDeviceToHost)); + } + + if (is_fp8_output && d_amax_host) { + NVTE_CHECK_CUDA(cudaMemcpy( + d_amax_host, dAmax, sizeof(float), cudaMemcpyDeviceToHost)); + } else if (d_amax_host) { + *d_amax_host = 0.0f; + } + + // cleanup + NVTE_CHECK_CUDA(cudaFree(dA)); + NVTE_CHECK_CUDA(cudaFree(dB)); + NVTE_CHECK_CUDA(cudaFree(dD)); + if (dBias) + NVTE_CHECK_CUDA(cudaFree(dBias)); + if (dGelu) + NVTE_CHECK_CUDA(cudaFree(dGelu)); + if (dAmax) + NVTE_CHECK_CUDA(cudaFree(dAmax)); + if (dA_scale) + NVTE_CHECK_CUDA(cudaFree(dA_scale)); + if (dB_scale) + NVTE_CHECK_CUDA(cudaFree(dB_scale)); +} + + template void compute_ref( const A_Type* a_data, @@ -71,36 +284,21 @@ void compute_ref( bool transa, bool transb){ - float ref_d_amax = 0; - - #pragma omp parallel for schedule(static) collapse(2) reduction(max: ref_d_amax) proc_bind(spread) - for(size_t ii = 0; ii < m; ii++){ - for(size_t jj = 0; jj < n; jj++){ - float val = 0; - for(size_t kk = 0; kk < k; kk++){ - float a_val = transa ? a_data[kk + ii*k] : a_data[ii + kk*m]; - float b_val = transb ? b_data[jj + kk*n] : b_data[kk + jj*k]; - val += a_scale_inv*a_val*b_scale_inv*b_val; - } - if(bias_data){ - val += (float)bias_data[ii]; - } - if(ref_gelu_data){ - ref_gelu_data[ii + jj*m] = (Gelu_Type)(val); - val = ref_gelu(val); - } - ref_d_data[ii+jj*m] = (D_Type)(val*d_scale); - // update ref_d_amax if in fp8 - DType dtype = TypeInfo::dtype; - if(isFp8Type(dtype)){ - ref_d_amax = std::max(ref_d_amax, std::fabs(val)); - } - } - } - if (ref_d_amax_ptr) - { - *ref_d_amax_ptr = ref_d_amax; - } + compute_ref_impl( + a_data, + b_data, + /*a_scale_inv_scalar=*/a_scale_inv, + /*b_scale_inv_scalar=*/b_scale_inv, + /*a_scale_inv_mxfp8=*/nullptr, + /*b_scale_inv_mxfp8=*/nullptr, + bias_data, + d_scale, + m, k, n, + ref_d_data, + ref_d_amax_ptr, + ref_gelu_data, + transa, + transb); } template @@ -118,38 +316,21 @@ void compute_mxfp8_ref( bool transa, bool transb){ - float ref_d_amax = 0; - - #pragma omp parallel for schedule(static) collapse(2) reduction(max: ref_d_amax) proc_bind(spread) - for(size_t ii = 0; ii < m; ii++){ - for(size_t jj = 0; jj < n; jj++){ - float val = 0; - for(size_t kk = 0; kk < k; kk++){ - size_t a_idx = transa ? (ii*k + kk) : (kk*m + ii); - size_t b_idx = transb ? (kk*n + jj) : (jj*k + kk); - float a_scale_inv_val = std::exp2f(a_scale_inv_data[transa ? a_idx/32 : (kk/32 * m + ii)] - 127); - float b_scale_inv_val = std::exp2f(b_scale_inv_data[transb ? 
(kk/32 * n + jj) : b_idx/32] - 127); - val += a_scale_inv_val * (float)a_data[a_idx] * b_scale_inv_val * (float)b_data[b_idx]; - } - if(bias_data){ - val += (float)bias_data[ii]; - } - if(ref_gelu_data){ - ref_gelu_data[ii + jj*m] = (Gelu_Type)(val); - val = ref_gelu(val); - } - ref_d_data[ii+jj*m] = (D_Type)(val*d_scale); - // update ref_d_amax if in fp8 - DType dtype = TypeInfo::dtype; - if(isFp8Type(dtype)){ - ref_d_amax = std::max(ref_d_amax, std::fabs(val)); - } - } - } - if (ref_d_amax_ptr) - { - *ref_d_amax_ptr = ref_d_amax; - } + compute_ref_impl( + a_data, + b_data, + /*a_scale_inv_scalar=*/1.0f, + /*b_scale_inv_scalar=*/1.0f, + /*a_scale_inv_mxfp8=*/a_scale_inv_data, + /*b_scale_inv_mxfp8=*/b_scale_inv_data, + bias_data, + d_scale, + m, k, n, + ref_d_data, + ref_d_amax_ptr, + ref_gelu_data, + transa, + transb); } template @@ -371,7 +552,7 @@ void performTest(const TestParams& params) { pre_gelu_out.to_cpu(); } - //perform the gemm in CPU + //perform the reference gemm on GPU std::unique_ptr ref_D = std::make_unique(params.m*params.n); std::unique_ptr ref_pre_gelu_out; if(params.use_gelu){ From 11e090b9e34f0fc792122e232af4e2b863122ef6 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 11 Dec 2025 15:14:53 -0600 Subject: [PATCH 02/13] blockwise amax --- tests/cpp/operator/test_cublaslt_gemm.cu | 86 +++++++++++++++--------- 1 file changed, 55 insertions(+), 31 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index e1e0b9316..0c5f9a759 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -78,51 +78,72 @@ __global__ void compute_ref_kernel( const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; - if (ii >= m || jj >= n) - return; + const bool in_range = (ii < m) && (jj < n); float val = 0.0f; - for (size_t kk = 0; kk < k; ++kk) { - const size_t a_idx = transa ? (ii * k + kk) : (kk * m + ii); - const size_t b_idx = transb ? (kk * n + jj) : (jj * k + kk); + if (in_range) { + for (size_t kk = 0; kk < k; ++kk) { + const size_t a_idx = transa ? (ii * k + kk) : (kk * m + ii); + const size_t b_idx = transb ? (kk * n + jj) : (jj * k + kk); - float a_scale_inv_val = a_scale_inv_scalar; - float b_scale_inv_val = b_scale_inv_scalar; + float a_scale_inv_val = a_scale_inv_scalar; + float b_scale_inv_val = b_scale_inv_scalar; - if (a_scale_inv_mxfp8) { - const size_t a_scale_idx = - transa ? (a_idx / 32) : ((kk / 32) * m + ii); - const size_t b_scale_idx = - transb ? ((kk / 32) * n + jj) : (b_idx / 32); + if (a_scale_inv_mxfp8) { + const size_t a_scale_idx = + transa ? (a_idx / 32) : ((kk / 32) * m + ii); + const size_t b_scale_idx = + transb ? 
((kk / 32) * n + jj) : (b_idx / 32); - const float a_byte = static_cast(a_scale_inv_mxfp8[a_scale_idx]); - const float b_byte = static_cast(b_scale_inv_mxfp8[b_scale_idx]); + const float a_byte = static_cast(a_scale_inv_mxfp8[a_scale_idx]); + const float b_byte = static_cast(b_scale_inv_mxfp8[b_scale_idx]); - a_scale_inv_val = exp2f(a_byte - 127.0f); - b_scale_inv_val = exp2f(b_byte - 127.0f); + a_scale_inv_val = exp2f(a_byte - 127.0f); + b_scale_inv_val = exp2f(b_byte - 127.0f); + } + + const float a_val = static_cast(a_data[a_idx]); + const float b_val = static_cast(b_data[b_idx]); + + val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; } - const float a_val = a_data[a_idx]; - const float b_val = b_data[b_idx]; + if (bias_data) { + val += static_cast(bias_data[ii]); + } - val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; - } + if (gelu_data) { + gelu_data[ii + jj * m] = static_cast(val); + val = ref_gelu(val); + } - if (bias_data) { - val += (float)bias_data[ii]; + const float scaled = val * d_scale; + d_data[ii + jj * m] = static_cast(scaled); } - if (gelu_data) { - gelu_data[ii + jj * m] = val; - val = ref_gelu(val); - } + // Blockwise reduction for amax + if (is_fp8_output && d_amax) { + const int tid = threadIdx.y * blockDim.x + threadIdx.x; + const int nthreads = blockDim.x * blockDim.y; - const float scaled = val * d_scale; - d_data[ii + jj * m] = scaled; + extern __shared__ float s_amax[]; - if (is_fp8_output && d_amax) { - atomicMax(d_amax, fabsf(val)); + // Out-of-range threads contribute 0 + s_amax[tid] = in_range ? fabsf(val) : 0.0f; + __syncthreads(); + + for (int offset = nthreads / 2; offset > 0; offset /= 2) { + if (tid < offset) { + s_amax[tid] = fmaxf(s_amax[tid], s_amax[tid + offset]); + } + __syncthreads(); + } + + if (tid == 0) { + const float block_max = s_amax[0]; + atomicMax(d_amax, block_max); + } } } @@ -214,8 +235,11 @@ static void compute_ref_impl( dim3 block(16, 16); dim3 grid((n + block.x - 1) / block.x, (m + block.y - 1) / block.y); + const int nthreads = block.x * block.y; + size_t shmem_bytes = nthreads * sizeof(float); + compute_ref_kernel - <<>>( + <<>>( dA, dB, a_scale_inv_scalar, From 3ecea7fb11748bc6c99250e3de46ebec68dfc778 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 13 Jan 2026 17:13:48 -0600 Subject: [PATCH 03/13] Change to use Tensor arguments, combine mxfp8/non-mxfp8 paths --- tests/cpp/operator/test_cublaslt_gemm.cu | 343 +++++++++-------------- tests/cpp/test_common.h | 14 +- 2 files changed, 137 insertions(+), 220 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 3f5249a6a..631c06c51 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -73,7 +73,9 @@ __global__ void compute_ref_kernel( Gelu_Type* __restrict__ gelu_data, bool transa, bool transb, - bool is_fp8_output) + bool is_fp8_output, + bool a_is_colwise, + bool b_is_colwise) { const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; @@ -84,17 +86,26 @@ __global__ void compute_ref_kernel( if (in_range) { for (size_t kk = 0; kk < k; ++kk) { - const size_t a_idx = transa ? (ii * k + kk) : (kk * m + ii); - const size_t b_idx = transb ? (kk * n + jj) : (jj * k + kk); + // Indexing depends on which backing buffer we passed in + const size_t a_idx = + a_is_colwise ? (ii * k + kk) + : (transa ? (ii * k + kk) : (kk * m + ii)); + + const size_t b_idx = + b_is_colwise ? (jj * k + kk) + : (transb ? 
(kk * n + jj) : (jj * k + kk)); float a_scale_inv_val = a_scale_inv_scalar; float b_scale_inv_val = b_scale_inv_scalar; if (a_scale_inv_mxfp8) { const size_t a_scale_idx = - transa ? (a_idx / 32) : ((kk / 32) * m + ii); + a_is_colwise ? (a_idx / 32) + : (transa ? (a_idx / 32) : ((kk / 32) * m + ii)); + const size_t b_scale_idx = - transb ? ((kk / 32) * n + jj) : (b_idx / 32); + b_is_colwise ? (b_idx / 32) + : (transb ? ((kk / 32) * n + jj) : (b_idx / 32)); const float a_byte = static_cast(a_scale_inv_mxfp8[a_scale_idx]); const float b_byte = static_cast(b_scale_inv_mxfp8[b_scale_idx]); @@ -147,216 +158,145 @@ __global__ void compute_ref_kernel( } } -// Common implementation used by both tensor-wise and MXFP8 frontends + +struct TestParams { + size_t m; + size_t k; + size_t n; + bool use_bias; + bool use_gelu; + bool transa; + bool transb; + NVTEScalingMode scaling_mode; +}; + + template -static void compute_ref_impl( - const A_Type* a_data, - const B_Type* b_data, - float a_scale_inv_scalar, // used when mxfp8 == false - float b_scale_inv_scalar, - const fp8e8m0* a_scale_inv_mxfp8, // used when mxfp8 == true - const fp8e8m0* b_scale_inv_mxfp8, - const Bias_Type* bias_data, - float d_scale, - size_t m, size_t k, size_t n, - D_Type* d_data, - float* d_amax_host, - Gelu_Type* gelu_data, - bool transa, - bool transb) +static void run_reference( + const TestParams& params, + const Tensor& A, + const Tensor& B, + const Tensor* Bias, // nullable + float d_scale, + std::unique_ptr& ref_D, // m*n + float* ref_amax_d, + std::unique_ptr& ref_pre_gelu_out) // nullable { - using transformer_engine::DType; - using ::TypeInfo; - using ::isFp8Type; + const bool use_mxfp8 = (params.scaling_mode == NVTE_MXFP8_1D_SCALING); - const bool use_mxfp8 = (a_scale_inv_mxfp8 != nullptr); + Gelu_Type* ref_gelu_host = (params.use_gelu ? ref_pre_gelu_out.get() : nullptr); - const DType dtype = TypeInfo::dtype; - const bool is_fp8_output = isFp8Type(dtype); + const bool is_fp8_output = test::isFp8Type(test::TypeInfo::dtype); - const size_t lenA = m * k; - const size_t lenB = k * n; - const size_t lenD = m * n; - const size_t lenBias = m; - const size_t lenGelu = m * n; + const bool a_use_colwise = (!params.transa) && A.columnwise(); + const bool b_use_colwise = ( params.transb) && B.columnwise(); - const size_t lenA_scale = use_mxfp8 ? (lenA + 31) / 32 : 0; - const size_t lenB_scale = use_mxfp8 ? (lenB + 31) / 32 : 0; + const A_Type* a_dev = static_cast( + a_use_colwise ? A.columnwise_dptr() : A.rowwise_dptr()); - A_Type* dA = nullptr; - B_Type* dB = nullptr; - Bias_Type* dBias = nullptr; - D_Type* dD = nullptr; - Gelu_Type* dGelu = nullptr; - float* dAmax = nullptr; - fp8e8m0* dA_scale = nullptr; - fp8e8m0* dB_scale = nullptr; + const B_Type* b_dev = static_cast( + b_use_colwise ? 
B.columnwise_dptr() : B.rowwise_dptr()); - // Allocations and H2D transfers - NVTE_CHECK_CUDA(cudaMalloc(&dA, lenA * sizeof(A_Type))); - NVTE_CHECK_CUDA(cudaMalloc(&dB, lenB * sizeof(B_Type))); - NVTE_CHECK_CUDA(cudaMalloc(&dD, lenD * sizeof(D_Type))); + // scaling inputs + float a_scale_inv_scalar = 1.0f; + float b_scale_inv_scalar = 1.0f; - NVTE_CHECK_CUDA(cudaMemcpy( - dA, a_data, lenA * sizeof(A_Type), cudaMemcpyHostToDevice)); - NVTE_CHECK_CUDA(cudaMemcpy( - dB, b_data, lenB * sizeof(B_Type), cudaMemcpyHostToDevice)); + const fp8e8m0* a_scale_dev = nullptr; + const fp8e8m0* b_scale_dev = nullptr; - if (bias_data) { - NVTE_CHECK_CUDA(cudaMalloc(&dBias, lenBias * sizeof(Bias_Type))); - NVTE_CHECK_CUDA(cudaMemcpy( - dBias, bias_data, lenBias * sizeof(Bias_Type), - cudaMemcpyHostToDevice)); - } + if (use_mxfp8) { + a_scale_dev = params.transa + ? (const fp8e8m0*) A.rowwise_scale_inv_dptr() + : (const fp8e8m0*) A.columnwise_scale_inv_dptr(); - if (gelu_data) { - NVTE_CHECK_CUDA(cudaMalloc(&dGelu, lenGelu * sizeof(Gelu_Type))); - NVTE_CHECK_CUDA(cudaMemset(dGelu, 0, lenGelu * sizeof(Gelu_Type))); + b_scale_dev = params.transb + ? (const fp8e8m0*) B.columnwise_scale_inv_dptr() + : (const fp8e8m0*) B.rowwise_scale_inv_dptr(); + } else { + a_scale_inv_scalar = A.rowwise_scale_inv(); + b_scale_inv_scalar = B.rowwise_scale_inv(); } - if (use_mxfp8) { - NVTE_CHECK_CUDA(cudaMalloc(&dA_scale, lenA_scale * sizeof(fp8e8m0))); - NVTE_CHECK_CUDA(cudaMalloc(&dB_scale, lenB_scale * sizeof(fp8e8m0))); - NVTE_CHECK_CUDA(cudaMemcpy( - dA_scale, a_scale_inv_mxfp8, lenA_scale * sizeof(fp8e8m0), - cudaMemcpyHostToDevice)); - NVTE_CHECK_CUDA(cudaMemcpy( - dB_scale, b_scale_inv_mxfp8, lenB_scale * sizeof(fp8e8m0), - cudaMemcpyHostToDevice)); + // optional bias device pointer + const Bias_Type* bias_dev = nullptr; + if (Bias) { + bias_dev = static_cast(Bias->rowwise_dptr()); } - if (is_fp8_output && d_amax_host) { - NVTE_CHECK_CUDA(cudaMalloc(&dAmax, sizeof(float))); - NVTE_CHECK_CUDA(cudaMemset(dAmax, 0, sizeof(float))); + // allocate device outputs + const size_t lenD = params.m * params.n; + const size_t bytesD = lenD * sizeof(D_Type); + + D_Type* d_refD = nullptr; + Gelu_Type* d_refGelu = nullptr; + float* d_refAmax = nullptr; + + NVTE_CHECK_CUDA(cudaMalloc(&d_refD, bytesD)); + if (ref_gelu_host) { + NVTE_CHECK_CUDA(cudaMalloc(&d_refGelu, lenD * sizeof(Gelu_Type))); + } + if (is_fp8_output && ref_amax_d) { + NVTE_CHECK_CUDA(cudaMalloc(&d_refAmax, sizeof(float))); + NVTE_CHECK_CUDA(cudaMemset(d_refAmax, 0, sizeof(float))); } // Kernel launch dim3 block(16, 16); - dim3 grid((n + block.x - 1) / block.x, (m + block.y - 1) / block.y); + dim3 grid((unsigned)((params.n + block.x - 1) / block.x), + (unsigned)((params.m + block.y - 1) / block.y)); - const int nthreads = block.x * block.y; - size_t shmem_bytes = nthreads * sizeof(float); + const size_t shmem_bytes = size_t(block.x) * size_t(block.y) * sizeof(float); compute_ref_kernel <<>>( - dA, - dB, + a_dev, + b_dev, a_scale_inv_scalar, b_scale_inv_scalar, - dA_scale, - dB_scale, - dBias, + a_scale_dev, + b_scale_dev, + bias_dev, d_scale, - m, k, n, - dD, - dAmax, - dGelu, - transa, - transb, - is_fp8_output); + params.m, params.k, params.n, + d_refD, + d_refAmax, + d_refGelu, + params.transa, + params.transb, + is_fp8_output, + a_use_colwise, + b_use_colwise); NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - // D2H copies - NVTE_CHECK_CUDA(cudaMemcpy( - d_data, dD, lenD * sizeof(D_Type), cudaMemcpyDeviceToHost)); + // copy outputs 
back + NVTE_CHECK_CUDA(cudaMemcpy(ref_D.get(), d_refD, bytesD, cudaMemcpyDeviceToHost)); - if (gelu_data) { - NVTE_CHECK_CUDA(cudaMemcpy( - gelu_data, dGelu, lenGelu * sizeof(Gelu_Type), - cudaMemcpyDeviceToHost)); + if (ref_gelu_host) { + NVTE_CHECK_CUDA(cudaMemcpy(ref_gelu_host, d_refGelu, lenD * sizeof(Gelu_Type), + cudaMemcpyDeviceToHost)); } - if (is_fp8_output && d_amax_host) { - NVTE_CHECK_CUDA(cudaMemcpy( - d_amax_host, dAmax, sizeof(float), cudaMemcpyDeviceToHost)); - } else if (d_amax_host) { - *d_amax_host = 0.0f; + if (ref_amax_d) { + if (is_fp8_output) { + NVTE_CHECK_CUDA(cudaMemcpy(ref_amax_d, d_refAmax, sizeof(float), + cudaMemcpyDeviceToHost)); + } else { + *ref_amax_d = 0.0f; + } } // cleanup - NVTE_CHECK_CUDA(cudaFree(dA)); - NVTE_CHECK_CUDA(cudaFree(dB)); - NVTE_CHECK_CUDA(cudaFree(dD)); - if (dBias) - NVTE_CHECK_CUDA(cudaFree(dBias)); - if (dGelu) - NVTE_CHECK_CUDA(cudaFree(dGelu)); - if (dAmax) - NVTE_CHECK_CUDA(cudaFree(dAmax)); - if (dA_scale) - NVTE_CHECK_CUDA(cudaFree(dA_scale)); - if (dB_scale) - NVTE_CHECK_CUDA(cudaFree(dB_scale)); + NVTE_CHECK_CUDA(cudaFree(d_refD)); + if (d_refGelu) + NVTE_CHECK_CUDA(cudaFree(d_refGelu)); + if (d_refAmax) + NVTE_CHECK_CUDA(cudaFree(d_refAmax)); } -template -void compute_ref( - const A_Type* a_data, - const B_Type* b_data, - const float a_scale_inv, - const float b_scale_inv, - const Bias_Type* bias_data, //bias is of dim m - const float d_scale, - size_t m, size_t k, size_t n, - D_Type* ref_d_data, - float* ref_d_amax_ptr, - Gelu_Type* ref_gelu_data, - bool transa, - bool transb){ - - compute_ref_impl( - a_data, - b_data, - /*a_scale_inv_scalar=*/a_scale_inv, - /*b_scale_inv_scalar=*/b_scale_inv, - /*a_scale_inv_mxfp8=*/nullptr, - /*b_scale_inv_mxfp8=*/nullptr, - bias_data, - d_scale, - m, k, n, - ref_d_data, - ref_d_amax_ptr, - ref_gelu_data, - transa, - transb); -} - -template -void compute_mxfp8_ref( - const A_Type* a_data, - const B_Type* b_data, - const fp8e8m0* a_scale_inv_data, - const fp8e8m0* b_scale_inv_data, - const Bias_Type* bias_data, //bias is of dim m - const float d_scale, - size_t m, size_t k, size_t n, - D_Type* ref_d_data, - float* ref_d_amax_ptr, - Gelu_Type* ref_gelu_data, - bool transa, - bool transb){ - - compute_ref_impl( - a_data, - b_data, - /*a_scale_inv_scalar=*/1.0f, - /*b_scale_inv_scalar=*/1.0f, - /*a_scale_inv_mxfp8=*/a_scale_inv_data, - /*b_scale_inv_mxfp8=*/b_scale_inv_data, - bias_data, - d_scale, - m, k, n, - ref_d_data, - ref_d_amax_ptr, - ref_gelu_data, - transa, - transb); -} - template void cpu_rowwise_to_columnwise( size_t m, size_t n, @@ -396,16 +336,6 @@ std::pair getTestTolerances(const DType type, bool use_fp8, bool return {atol, rtol}; } -struct TestParams { - size_t m; - size_t k; - size_t n; - bool use_bias; - bool use_gelu; - bool transa; - bool transb; - NVTEScalingMode scaling_mode; -}; template void performTest(const TestParams& params) { @@ -588,40 +518,17 @@ void performTest(const TestParams& params) { } float ref_amax_d; - if (use_mxfp8) { - const A_Type *a_data; - const B_Type *b_data; - const fp8e8m0 *a_scale_inv_data, *b_scale_inv_data; - if (params.transa) { - a_data = A.rowwise_cpu_dptr(); - a_scale_inv_data = A.rowwise_cpu_scale_inv_ptr(); - } else { - a_data = A.columnwise_cpu_dptr(); - a_scale_inv_data = A.columnwise_cpu_scale_inv_ptr(); - } - if (params.transb) { - b_data = B.columnwise_cpu_dptr(); - b_scale_inv_data = B.columnwise_cpu_scale_inv_ptr(); - } else { - b_data = B.rowwise_cpu_dptr(); - b_scale_inv_data = B.rowwise_cpu_scale_inv_ptr(); - } - 
compute_mxfp8_ref( - a_data, b_data, a_scale_inv_data, b_scale_inv_data, - params.use_bias ? bias.rowwise_cpu_dptr() : nullptr, - D.scale(), params.m, params.k, params.n, ref_D.get(), &ref_amax_d, - params.use_gelu ? ref_pre_gelu_out.get() : nullptr, - params.transa, params.transb); - } else { - compute_ref( - A.rowwise_cpu_dptr(), B.rowwise_cpu_dptr(), - A.rowwise_scale_inv(), B.rowwise_scale_inv(), - params.use_bias ? bias.rowwise_cpu_dptr() : nullptr, - D.scale(), params.m, params.k, params.n, ref_D.get(), &ref_amax_d, - params.use_gelu ? ref_pre_gelu_out.get() : nullptr, - params.transa, params.transb); - } + run_reference( + params, + A, + B, + params.use_bias ? &bias : nullptr, + D.scale(), + ref_D, + &ref_amax_d, + ref_pre_gelu_out); + // check if error message happens in running (void)cudaDeviceSynchronize(); auto err = cudaGetLastError(); diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index bfb46f8a0..8892ff097 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -224,6 +224,16 @@ class Tensor { return reinterpret_cast(cpu_data_columnwise_.get()); } + void *rowwise_scale_inv_dptr() const { + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return tensor_.get_rowwise_scale_inv().data_ptr; + } + + void *columnwise_scale_inv_dptr() const { + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return tensor_.get_columnwise_scale_inv().data_ptr; + } + float amax() const { if(amax_cpu_data_) { to_cpu(); @@ -244,7 +254,7 @@ class Tensor { } template - T *rowwise_cpu_scale_inv_ptr(){ + T *rowwise_cpu_scale_inv_ptr() const { if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { @@ -269,7 +279,7 @@ class Tensor { return reinterpret_cast(columnwise_scale_inv_cpu_data_.get()); } - float rowwise_scale_inv(){ + float rowwise_scale_inv() const { if(rowwise_scale_inv_cpu_data_) { float scale_inv = rowwise_cpu_scale_inv_ptr()[0]; return scale_inv; From 86fbbac87113f00341062e4a9b150a855207acd6 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Jan 2026 14:17:24 -0600 Subject: [PATCH 04/13] skip on SwizzleScale limitation on gfx950 --- tests/cpp/operator/test_cublaslt_gemm.cu | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 560218575..da59a8dee 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -427,6 +427,11 @@ void performTest(const TestParams& params) { GTEST_SKIP() << "FP8 GEMM with bias is not supported in current config"; } } + + if (use_mxfp8 && (isFp8Type(atype) || isFp8Type(btype)) && (params.transa != true || params.transb != false)) { + GTEST_SKIP() << "On gfx950, MXFP8 FP8/BF8 GEMM currently requires TN (SwizzleScale limitation)."; + } + } if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations { From 54de3dbd3891e0a0d0f0962fe3ccc4a9eaac759f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Jan 2026 21:44:57 +0000 Subject: [PATCH 05/13] Revert "skip on SwizzleScale limitation on gfx950" This reverts commit 86fbbac87113f00341062e4a9b150a855207acd6. 
--- tests/cpp/operator/test_cublaslt_gemm.cu | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index da59a8dee..560218575 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -427,11 +427,6 @@ void performTest(const TestParams& params) { GTEST_SKIP() << "FP8 GEMM with bias is not supported in current config"; } } - - if (use_mxfp8 && (isFp8Type(atype) || isFp8Type(btype)) && (params.transa != true || params.transb != false)) { - GTEST_SKIP() << "On gfx950, MXFP8 FP8/BF8 GEMM currently requires TN (SwizzleScale limitation)."; - } - } if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations { From 311ddfe66bbe738ab550b74dccaf5fb8d885438d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Jan 2026 17:21:58 -0600 Subject: [PATCH 06/13] MXFP8 fix --- tests/cpp/operator/test_cublaslt_gemm.cu | 8 +++----- tests/cpp/test_common.h | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 560218575..3d15ac3d4 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -107,11 +107,9 @@ __global__ void compute_ref_kernel( b_is_colwise ? (b_idx / 32) : (transb ? ((kk / 32) * n + jj) : (b_idx / 32)); - const float a_byte = static_cast(a_scale_inv_mxfp8[a_scale_idx]); - const float b_byte = static_cast(b_scale_inv_mxfp8[b_scale_idx]); - - a_scale_inv_val = exp2f(a_byte - 127.0f); - b_scale_inv_val = exp2f(b_byte - 127.0f); + // scale_inv is stored as an e8m0 biased exponent; convert to 2^(127-exp) + a_scale_inv_val = exp2f_rcp(a_scale_inv_mxfp8[a_scale_idx]); + b_scale_inv_val = exp2f_rcp(b_scale_inv_mxfp8[b_scale_idx]); } const float a_val = static_cast(a_data[a_idx]); diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 8892ff097..2114feacc 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -446,7 +446,7 @@ inline fp8e8m0 float_to_e8m0(float val) { return exponent; } -inline float exp2f_rcp(fp8e8m0 biased_exp) { +__device__ __host__ __forceinline__ float exp2f_rcp(fp8e8m0 biased_exp) { return (biased_exp == 0) ? 
1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); } From 445e64fbce9060bfe5d0f23dedf5de209bcf353f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 15 Jan 2026 14:14:53 -0600 Subject: [PATCH 07/13] =?UTF-8?q?correct=20scale=5Finv=20packing=20and=20e?= =?UTF-8?q?xp2(biased=E2=88=92127)=20conversion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/cpp/operator/test_cublaslt_gemm.cu | 99 ++++++++++++++++++------ tests/cpp/test_common.h | 2 +- 2 files changed, 75 insertions(+), 26 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 3d15ac3d4..376a5fc26 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -75,8 +75,10 @@ __global__ void compute_ref_kernel( bool transb, bool is_fp8_output, bool a_is_colwise, - bool b_is_colwise) + bool b_is_colwise, + bool use_mxfp8) { + const size_t k_chunks = k / 32; const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; @@ -86,30 +88,33 @@ __global__ void compute_ref_kernel( if (in_range) { for (size_t kk = 0; kk < k; ++kk) { - // Indexing depends on which backing buffer we passed in - const size_t a_idx = - a_is_colwise ? (ii * k + kk) - : (transa ? (ii * k + kk) : (kk * m + ii)); - - const size_t b_idx = - b_is_colwise ? (jj * k + kk) - : (transb ? (kk * n + jj) : (jj * k + kk)); + size_t a_idx = 0; + size_t b_idx = 0; + + if (use_mxfp8) { + a_idx = transa ? (ii * k + kk) : (kk * m + ii); + b_idx = transb ? (kk * n + jj) : (jj * k + kk); + } else { + // Non-MXFP8 FP8 path may use explicit transpose buffers (cpu_rowwise_to_columnwise), + // so indexing depends on which backing buffer is passed in. + a_idx = a_is_colwise ? (ii * k + kk) + : (transa ? (ii * k + kk) : (kk * m + ii)); + + b_idx = b_is_colwise ? (jj * k + kk) + : (transb ? (kk * n + jj) : (jj * k + kk)); + } float a_scale_inv_val = a_scale_inv_scalar; float b_scale_inv_val = b_scale_inv_scalar; if (a_scale_inv_mxfp8) { - const size_t a_scale_idx = - a_is_colwise ? (a_idx / 32) - : (transa ? (a_idx / 32) : ((kk / 32) * m + ii)); + const size_t kc = kk / 32; - const size_t b_scale_idx = - b_is_colwise ? (b_idx / 32) - : (transb ? ((kk / 32) * n + jj) : (b_idx / 32)); + const size_t a_scale_idx = ii * k_chunks + kc; + const size_t b_scale_idx = jj * k_chunks + kc; - // scale_inv is stored as an e8m0 biased exponent; convert to 2^(127-exp) - a_scale_inv_val = exp2f_rcp(a_scale_inv_mxfp8[a_scale_idx]); - b_scale_inv_val = exp2f_rcp(b_scale_inv_mxfp8[b_scale_idx]); + a_scale_inv_val = exp2f(a_scale_inv_mxfp8[a_scale_idx] - 127.0f); + b_scale_inv_val = exp2f(b_scale_inv_mxfp8[b_scale_idx] - 127.0f); } const float a_val = static_cast(a_data[a_idx]); @@ -183,6 +188,8 @@ static void run_reference( { const bool use_mxfp8 = (params.scaling_mode == NVTE_MXFP8_1D_SCALING); + const size_t k_chunks = params.k / 32; + Gelu_Type* ref_gelu_host = (params.use_gelu ? ref_pre_gelu_out.get() : nullptr); const bool is_fp8_output = test::isFp8Type(test::TypeInfo::dtype); @@ -203,14 +210,51 @@ static void run_reference( const fp8e8m0* a_scale_dev = nullptr; const fp8e8m0* b_scale_dev = nullptr; + // If MXFP8, pack scale_inv into tight [row][kc] buffers on host, then transfer to device + std::vector a_scale_packed; + std::vector b_scale_packed; + fp8e8m0* d_a_scale_packed = nullptr; + fp8e8m0* d_b_scale_packed = nullptr; + if (use_mxfp8) { - a_scale_dev = params.transa - ? 
(const fp8e8m0*) A.rowwise_scale_inv_dptr() - : (const fp8e8m0*) A.columnwise_scale_inv_dptr(); + const fp8e8m0* a_scale_cpu = params.transa + ? A.rowwise_cpu_scale_inv_ptr() + : A.columnwise_cpu_scale_inv_ptr(); + const fp8e8m0* b_scale_cpu = params.transb + ? B.columnwise_cpu_scale_inv_ptr() + : B.rowwise_cpu_scale_inv_ptr(); + + // Pack into row-major [row][kc]: + // A_packed[ii, kc] and B_packed[jj, kc] + a_scale_packed.resize(params.m * k_chunks); + b_scale_packed.resize(params.n * k_chunks); + + for (size_t ii = 0; ii < params.m; ++ii) { + for (size_t kc = 0; kc < k_chunks; ++kc) { + const size_t src_idx = params.transa ? (ii * k_chunks + kc) : (kc * params.m + ii); + a_scale_packed[ii * k_chunks + kc] = a_scale_cpu[src_idx]; + } + } + + for (size_t jj = 0; jj < params.n; ++jj) { + for (size_t kc = 0; kc < k_chunks; ++kc) { + const size_t src_idx = params.transb ? (kc * params.n + jj) : (jj * k_chunks + kc); + b_scale_packed[jj * k_chunks + kc] = b_scale_cpu[src_idx]; + } + } + + NVTE_CHECK_CUDA(cudaMalloc(&d_a_scale_packed, a_scale_packed.size() * sizeof(fp8e8m0))); + NVTE_CHECK_CUDA(cudaMalloc(&d_b_scale_packed, b_scale_packed.size() * sizeof(fp8e8m0))); + + NVTE_CHECK_CUDA(cudaMemcpy(d_a_scale_packed, a_scale_packed.data(), + a_scale_packed.size() * sizeof(fp8e8m0), + cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(d_b_scale_packed, b_scale_packed.data(), + b_scale_packed.size() * sizeof(fp8e8m0), + cudaMemcpyHostToDevice)); - b_scale_dev = params.transb - ? (const fp8e8m0*) B.columnwise_scale_inv_dptr() - : (const fp8e8m0*) B.rowwise_scale_inv_dptr(); + a_scale_dev = d_a_scale_packed; + b_scale_dev = d_b_scale_packed; } else { a_scale_inv_scalar = A.rowwise_scale_inv(); b_scale_inv_scalar = B.rowwise_scale_inv(); @@ -264,7 +308,8 @@ static void run_reference( params.transb, is_fp8_output, a_use_colwise, - b_use_colwise); + b_use_colwise, + use_mxfp8); NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaDeviceSynchronize()); @@ -292,6 +337,10 @@ static void run_reference( NVTE_CHECK_CUDA(cudaFree(d_refGelu)); if (d_refAmax) NVTE_CHECK_CUDA(cudaFree(d_refAmax)); + if (d_a_scale_packed) + NVTE_CHECK_CUDA(cudaFree(d_a_scale_packed)); + if (d_b_scale_packed) + NVTE_CHECK_CUDA(cudaFree(d_b_scale_packed)); } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 2114feacc..7596bcf06 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -267,7 +267,7 @@ class Tensor { } template - T *columnwise_cpu_scale_inv_ptr(){ + T *columnwise_cpu_scale_inv_ptr() const { if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { From 462945fc299deca92a99e783fb1f71f4ae034252 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 15 Jan 2026 15:27:42 -0600 Subject: [PATCH 08/13] cleanups --- tests/cpp/operator/test_cublaslt_gemm.cu | 2 +- tests/cpp/test_common.h | 12 +----------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 376a5fc26..21e4d4be6 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -203,7 +203,7 @@ static void run_reference( const B_Type* b_dev = static_cast( b_use_colwise ? 
B.columnwise_dptr() : B.rowwise_dptr()); - // scaling inputs + // scaling inputs float a_scale_inv_scalar = 1.0f; float b_scale_inv_scalar = 1.0f; diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 7596bcf06..07b4cd9bf 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -224,16 +224,6 @@ class Tensor { return reinterpret_cast(cpu_data_columnwise_.get()); } - void *rowwise_scale_inv_dptr() const { - NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); - return tensor_.get_rowwise_scale_inv().data_ptr; - } - - void *columnwise_scale_inv_dptr() const { - NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); - return tensor_.get_columnwise_scale_inv().data_ptr; - } - float amax() const { if(amax_cpu_data_) { to_cpu(); @@ -446,7 +436,7 @@ inline fp8e8m0 float_to_e8m0(float val) { return exponent; } -__device__ __host__ __forceinline__ float exp2f_rcp(fp8e8m0 biased_exp) { +inline float exp2f_rcp(fp8e8m0 biased_exp) { return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); } From e11e40034c7adc2a0845b4fd66529f9f0929669b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 22 Jan 2026 16:30:38 -0600 Subject: [PATCH 09/13] use Tensor class for more device objects --- tests/cpp/operator/test_cublaslt_gemm.cu | 111 +++++++++-------------- tests/cpp/test_common.h | 10 ++ 2 files changed, 53 insertions(+), 68 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 21e4d4be6..33d3fd85a 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ @@ -65,6 +65,10 @@ __global__ void compute_ref_kernel( float b_scale_inv_scalar, const fp8e8m0* __restrict__ a_scale_inv_mxfp8, // used when mxfp8 == true const fp8e8m0* __restrict__ b_scale_inv_mxfp8, + size_t a_scale_ld, + size_t b_scale_ld, + bool a_scale_is_colwise, + bool b_scale_is_colwise, const Bias_Type* __restrict__ bias_data, float d_scale, size_t m, size_t k, size_t n, @@ -110,8 +114,10 @@ __global__ void compute_ref_kernel( if (a_scale_inv_mxfp8) { const size_t kc = kk / 32; - const size_t a_scale_idx = ii * k_chunks + kc; - const size_t b_scale_idx = jj * k_chunks + kc; + const size_t a_scale_idx = + a_scale_is_colwise ? (kc * a_scale_ld + ii) : (ii * a_scale_ld + kc); + const size_t b_scale_idx = + b_scale_is_colwise ? (kc * b_scale_ld + jj) : (jj * b_scale_ld + kc); a_scale_inv_val = exp2f(a_scale_inv_mxfp8[a_scale_idx] - 127.0f); b_scale_inv_val = exp2f(b_scale_inv_mxfp8[b_scale_idx] - 127.0f); @@ -209,52 +215,22 @@ static void run_reference( const fp8e8m0* a_scale_dev = nullptr; const fp8e8m0* b_scale_dev = nullptr; - - // If MXFP8, pack scale_inv into tight [row][kc] buffers on host, then transfer to device - std::vector a_scale_packed; - std::vector b_scale_packed; - fp8e8m0* d_a_scale_packed = nullptr; - fp8e8m0* d_b_scale_packed = nullptr; + size_t a_scale_ld = 0; + size_t b_scale_ld = 0; + bool a_scale_is_colwise = !params.transa; + bool b_scale_is_colwise = params.transb; if (use_mxfp8) { - const fp8e8m0* a_scale_cpu = params.transa - ? 
A.rowwise_cpu_scale_inv_ptr() - : A.columnwise_cpu_scale_inv_ptr(); - const fp8e8m0* b_scale_cpu = params.transb - ? B.columnwise_cpu_scale_inv_ptr() - : B.rowwise_cpu_scale_inv_ptr(); - - // Pack into row-major [row][kc]: - // A_packed[ii, kc] and B_packed[jj, kc] - a_scale_packed.resize(params.m * k_chunks); - b_scale_packed.resize(params.n * k_chunks); - - for (size_t ii = 0; ii < params.m; ++ii) { - for (size_t kc = 0; kc < k_chunks; ++kc) { - const size_t src_idx = params.transa ? (ii * k_chunks + kc) : (kc * params.m + ii); - a_scale_packed[ii * k_chunks + kc] = a_scale_cpu[src_idx]; - } - } - - for (size_t jj = 0; jj < params.n; ++jj) { - for (size_t kc = 0; kc < k_chunks; ++kc) { - const size_t src_idx = params.transb ? (kc * params.n + jj) : (jj * k_chunks + kc); - b_scale_packed[jj * k_chunks + kc] = b_scale_cpu[src_idx]; - } - } - - NVTE_CHECK_CUDA(cudaMalloc(&d_a_scale_packed, a_scale_packed.size() * sizeof(fp8e8m0))); - NVTE_CHECK_CUDA(cudaMalloc(&d_b_scale_packed, b_scale_packed.size() * sizeof(fp8e8m0))); - - NVTE_CHECK_CUDA(cudaMemcpy(d_a_scale_packed, a_scale_packed.data(), - a_scale_packed.size() * sizeof(fp8e8m0), - cudaMemcpyHostToDevice)); - NVTE_CHECK_CUDA(cudaMemcpy(d_b_scale_packed, b_scale_packed.data(), - b_scale_packed.size() * sizeof(fp8e8m0), - cudaMemcpyHostToDevice)); - - a_scale_dev = d_a_scale_packed; - b_scale_dev = d_b_scale_packed; + a_scale_dev = static_cast( + a_scale_is_colwise ? A.columnwise_scale_inv_dptr() : A.rowwise_scale_inv_dptr()); + b_scale_dev = static_cast( + b_scale_is_colwise ? B.columnwise_scale_inv_dptr() : B.rowwise_scale_inv_dptr()); + + const NVTEShape a_s = a_scale_is_colwise ? A.columnwise_scale_inv_shape() : A.rowwise_scale_inv_shape(); + const NVTEShape b_s = b_scale_is_colwise ? B.columnwise_scale_inv_shape() : B.rowwise_scale_inv_shape(); + NVTE_CHECK(a_s.ndim == 2 && b_s.ndim == 2, "Expected 2D MXFP8 scale_inv"); + a_scale_ld = a_s.data[1]; + b_scale_ld = b_s.data[1]; } else { a_scale_inv_scalar = A.rowwise_scale_inv(); b_scale_inv_scalar = B.rowwise_scale_inv(); @@ -266,20 +242,25 @@ static void run_reference( bias_dev = static_cast(Bias->rowwise_dptr()); } - // allocate device outputs + // allocate device outputs as test::Tensor objects const size_t lenD = params.m * params.n; const size_t bytesD = lenD * sizeof(D_Type); - D_Type* d_refD = nullptr; + Tensor RefD("RefD", TShape{params.n, params.m}, TypeInfo::dtype); + D_Type* d_refD = static_cast(RefD.rowwise_dptr()); + + Tensor RefGelu; + Tensor RefAmax; Gelu_Type* d_refGelu = nullptr; float* d_refAmax = nullptr; - NVTE_CHECK_CUDA(cudaMalloc(&d_refD, bytesD)); if (ref_gelu_host) { - NVTE_CHECK_CUDA(cudaMalloc(&d_refGelu, lenD * sizeof(Gelu_Type))); + RefGelu = Tensor("RefGelu", TShape{params.n, params.m}, TypeInfo::dtype); + d_refGelu = static_cast(RefGelu.rowwise_dptr()); } if (is_fp8_output && ref_amax_d) { - NVTE_CHECK_CUDA(cudaMalloc(&d_refAmax, sizeof(float))); + RefAmax = Tensor("RefAmax", TShape{1}, DType::kFloat32); + d_refAmax = static_cast(RefAmax.rowwise_dptr()); NVTE_CHECK_CUDA(cudaMemset(d_refAmax, 0, sizeof(float))); } @@ -298,6 +279,10 @@ static void run_reference( b_scale_inv_scalar, a_scale_dev, b_scale_dev, + a_scale_ld, + b_scale_ld, + a_scale_is_colwise, + b_scale_is_colwise, bias_dev, d_scale, params.m, params.k, params.n, @@ -315,32 +300,22 @@ static void run_reference( NVTE_CHECK_CUDA(cudaDeviceSynchronize()); // copy outputs back - NVTE_CHECK_CUDA(cudaMemcpy(ref_D.get(), d_refD, bytesD, cudaMemcpyDeviceToHost)); + RefD.to_cpu(); + memcpy(ref_D.get(), 
RefD.rowwise_cpu_dptr(), bytesD); if (ref_gelu_host) { - NVTE_CHECK_CUDA(cudaMemcpy(ref_gelu_host, d_refGelu, lenD * sizeof(Gelu_Type), - cudaMemcpyDeviceToHost)); + RefGelu.to_cpu(); + memcpy(ref_gelu_host, RefGelu.rowwise_cpu_dptr(), lenD * sizeof(Gelu_Type)); } if (ref_amax_d) { if (is_fp8_output) { - NVTE_CHECK_CUDA(cudaMemcpy(ref_amax_d, d_refAmax, sizeof(float), - cudaMemcpyDeviceToHost)); + RefAmax.to_cpu(); + *ref_amax_d = RefAmax.rowwise_cpu_dptr()[0]; } else { *ref_amax_d = 0.0f; } } - - // cleanup - NVTE_CHECK_CUDA(cudaFree(d_refD)); - if (d_refGelu) - NVTE_CHECK_CUDA(cudaFree(d_refGelu)); - if (d_refAmax) - NVTE_CHECK_CUDA(cudaFree(d_refAmax)); - if (d_a_scale_packed) - NVTE_CHECK_CUDA(cudaFree(d_a_scale_packed)); - if (d_b_scale_packed) - NVTE_CHECK_CUDA(cudaFree(d_b_scale_packed)); } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index e181bce68..17db2a021 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -298,6 +298,16 @@ class Tensor { std::mt19937& gen() { return gen_; } + void *rowwise_scale_inv_dptr() const { + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return tensor_.scale_inv(); // rowwise scale_inv backing storage + } + + void *columnwise_scale_inv_dptr() const { + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return tensor_.get_columnwise_scale_inv().data_ptr; + } + private: TensorWrapper tensor_; std::unique_ptr cpu_data_rowwise_; From 325ece611769ceb0af3bb1af26d53838646871ca Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 23 Jan 2026 14:11:48 -0600 Subject: [PATCH 10/13] Pass D Tensor into run_reference and move RefD allocation into PerformTest --- tests/cpp/operator/test_cublaslt_gemm.cu | 77 ++++++++---------------- tests/cpp/test_common.h | 4 ++ 2 files changed, 28 insertions(+), 53 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 33d3fd85a..e1c963734 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -82,7 +82,6 @@ __global__ void compute_ref_kernel( bool b_is_colwise, bool use_mxfp8) { - const size_t k_chunks = k / 32; const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; @@ -187,16 +186,13 @@ static void run_reference( const Tensor& A, const Tensor& B, const Tensor* Bias, // nullable - float d_scale, - std::unique_ptr& ref_D, // m*n - float* ref_amax_d, - std::unique_ptr& ref_pre_gelu_out) // nullable + const Tensor& D_for_scale, + Tensor& RefD, + Tensor* RefPreGeluOut) // nullable { const bool use_mxfp8 = (params.scaling_mode == NVTE_MXFP8_1D_SCALING); - const size_t k_chunks = params.k / 32; - - Gelu_Type* ref_gelu_host = (params.use_gelu ? 
ref_pre_gelu_out.get() : nullptr); + const float d_scale = D_for_scale.scale(); const bool is_fp8_output = test::isFp8Type(test::TypeInfo::dtype); @@ -242,26 +238,19 @@ static void run_reference( bias_dev = static_cast(Bias->rowwise_dptr()); } - // allocate device outputs as test::Tensor objects - const size_t lenD = params.m * params.n; - const size_t bytesD = lenD * sizeof(D_Type); - - Tensor RefD("RefD", TShape{params.n, params.m}, TypeInfo::dtype); D_Type* d_refD = static_cast(RefD.rowwise_dptr()); - Tensor RefGelu; - Tensor RefAmax; Gelu_Type* d_refGelu = nullptr; float* d_refAmax = nullptr; - if (ref_gelu_host) { - RefGelu = Tensor("RefGelu", TShape{params.n, params.m}, TypeInfo::dtype); - d_refGelu = static_cast(RefGelu.rowwise_dptr()); + if (RefPreGeluOut) { + d_refGelu = static_cast(RefPreGeluOut->rowwise_dptr()); } - if (is_fp8_output && ref_amax_d) { - RefAmax = Tensor("RefAmax", TShape{1}, DType::kFloat32); - d_refAmax = static_cast(RefAmax.rowwise_dptr()); - NVTE_CHECK_CUDA(cudaMemset(d_refAmax, 0, sizeof(float))); + + if (is_fp8_output) { + d_refAmax = static_cast(RefD.amax_dptr()); + if (d_refAmax) + NVTE_CHECK_CUDA(cudaMemset(d_refAmax, 0, sizeof(float))); } // Kernel launch @@ -297,25 +286,6 @@ static void run_reference( use_mxfp8); NVTE_CHECK_CUDA(cudaGetLastError()); - NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - - // copy outputs back - RefD.to_cpu(); - memcpy(ref_D.get(), RefD.rowwise_cpu_dptr(), bytesD); - - if (ref_gelu_host) { - RefGelu.to_cpu(); - memcpy(ref_gelu_host, RefGelu.rowwise_cpu_dptr(), lenD * sizeof(Gelu_Type)); - } - - if (ref_amax_d) { - if (is_fp8_output) { - RefAmax.to_cpu(); - *ref_amax_d = RefAmax.rowwise_cpu_dptr()[0]; - } else { - *ref_amax_d = 0.0f; - } - } } @@ -541,23 +511,21 @@ void performTest(const TestParams& params) { } //perform the reference gemm on GPU - std::unique_ptr ref_D = std::make_unique(params.m*params.n); - std::unique_ptr ref_pre_gelu_out; - if(params.use_gelu){ - ref_pre_gelu_out = std::make_unique(params.m*params.n); - } + Tensor RefD("RefD", TShape{ params.n, params.m }, dtype); + Tensor RefPreGeluOut; - float ref_amax_d; + if (params.use_gelu) { + RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type); + } run_reference( params, A, B, params.use_bias ? &bias : nullptr, - D.scale(), - ref_D, - &ref_amax_d, - ref_pre_gelu_out); + D, + RefD, + params.use_gelu ? 
&RefPreGeluOut : nullptr); // check if error message happens in running (void)cudaDeviceSynchronize(); @@ -567,15 +535,18 @@ void performTest(const TestParams& params) { //compare results auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); if (isFp8Type(dtype)) { + const float ref_amax_d = RefD.amax(); compareResults("D_amax", D.amax(), ref_amax_d, atol_amax, rtol_amax); } auto [atol, rtol] = getTestTolerances(dtype, has_fp8, use_mxfp8); - compareResults("D", D, ref_D.get(), true, atol, rtol); + RefD.to_cpu(); + compareResults("D", D, RefD.rowwise_cpu_dptr(), true, atol, rtol); if(params.use_gelu){ auto [atol, rtol] = getTestTolerances(gelu_type, false, false); - compareResults("gelu", pre_gelu_out, ref_pre_gelu_out.get(), true, atol, rtol); + RefPreGeluOut.to_cpu(); + compareResults("gelu", pre_gelu_out, RefPreGeluOut.rowwise_cpu_dptr(), true, atol, rtol); } } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 17db2a021..b824f8d4d 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -233,6 +233,10 @@ class Tensor { } } + void *amax_dptr() const { + return tensor_.amax(); + } + float scale() const { if(scale_cpu_data_) { NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!"); From fc64b8cec14026905fdaa43be96beb5ce407552c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 26 Jan 2026 17:06:23 -0600 Subject: [PATCH 11/13] [WIP] proof-of-concept: grouped GEMM with ck_tile --- gmm2.py | 62 +++ transformer_engine/common/CMakeLists.txt | 8 + .../common/gemm/ck_grouped_gemm.cuh | 449 ++++++++++++++++++ .../common/gemm/cublaslt_gemm.cu | 38 +- 4 files changed, 555 insertions(+), 2 deletions(-) create mode 100644 gmm2.py create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm.cuh diff --git a/gmm2.py b/gmm2.py new file mode 100644 index 000000000..e5bceebbb --- /dev/null +++ b/gmm2.py @@ -0,0 +1,62 @@ +import os, torch +import transformer_engine.pytorch as te +from time import time + +torch.manual_seed(0) + +os.environ["NVTE_USE_CK_GROUPED_GEMM"] = "1" +os.environ["NVTE_CK_GROUPED_GEMM_WARN_FALLBACK"] = "1" + +device = "cuda" +dtype = torch.bfloat16 + +E = 4 +K = 1024 +N = 2048 +m_splits = [128, 64, 0, 256] +M_total = sum(m_splits) + +x = torch.randn(M_total, K, device=device, dtype=dtype) + +# TE +start = time() + +glinear = te.GroupedLinear(E, K, N, bias=False).to(device=device, dtype=dtype) +y_te = glinear(x, m_splits=m_splits) +print("TE time: ", time()-start) + + +Ws = [] +for e in range(E): + w = getattr(glinear, f"weight{e}") # expect [N, K] + Ws.append(w) +W = torch.stack(Ws, dim=0) # [E, N, K] +assert W.shape == (E, N, K), f"Unexpected weight shape: {W.shape}" + + +# Torch +start = time() + +ys = [] +offset = 0 +for e, m in enumerate(m_splits): + if m == 0: + continue + x_e = x[offset:offset+m] # [m, K] + y_e = x_e @ W[e].transpose(0, 1) # [m, N] + ys.append(y_e) + offset += m + +y_ref = torch.cat(ys, dim=0) +print("Torch time:", time()-start) + +# Compare +diff = (y_te.float() - y_ref.float()) +max_abs = diff.abs().max().item() +rel = (diff.abs() / (y_ref.float().abs() + 1e-6)).max().item() + +print(f"{y_te.shape=}, {y_ref.shape=}") +print("max_abs_err:", max_abs) +print("max_rel_err:", rel) + +torch.testing.assert_close(y_te.float(), y_ref.float(), rtol=3e-2, atol=3e-2) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index cefec6d06..56207f16d 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt 
@@ -241,6 +241,14 @@ endif() target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") +set(CK_ROOT ${CMAKE_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel) + +target_include_directories(transformer_engine + BEFORE PRIVATE + ${CK_ROOT}/include +) + + if (USE_CUDA) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) set_source_files_properties( diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cuh b/transformer_engine/common/gemm/ck_grouped_gemm.cuh new file mode 100644 index 000000000..fa1f1cca1 --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cuh @@ -0,0 +1,449 @@ +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +static inline int get_num_cu_for_stream(hipStream_t stream) { + int device = -1; + hipError_t st = hipGetDevice(&device); + if (st != hipSuccess) + return 0; + + hipDeviceProp_t prop{}; + st = hipGetDeviceProperties(&prop, device); + if (st != hipSuccess) + return 0; + + return prop.multiProcessorCount; +} + +// Map TE DType to CK_Tile scalar type +template +struct TeDTypeToCk; + +template <> struct TeDTypeToCk { + using type = ck_tile::half_t; +}; +template <> struct TeDTypeToCk { + using type = ck_tile::bfloat16_t; +}; + +// TE Tensor -> SimpleTensor view +static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { + // For GEMM we want the "data" view (rowwise) + return t.data; +} + +// CK_Tile runner + +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; + +struct TileCfg_basic { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool kPadM = true; + static constexpr bool kPadN = true; + static constexpr bool kPadK = true; + + static constexpr bool DoubleSmemBuffer = false; + + // Spatially-local partitioner parameters + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 1; +}; + +template +inline void launch_grouped_kernel(const ck_tile::stream_config& stream_cfg, + ck_tile::index_t group_num, + void* args_ptr, + uint32_t num_cu) { + constexpr int kBlockPerCu = 1; + const dim3 blocks = Kernel::BlockSize(); + dim3 grids = Kernel::MaxOccupancyGridSize(stream_cfg); + grids.x = std::min(grids.x, static_cast(num_cu)); + ck_tile::launch_kernel( + stream_cfg, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(args_ptr), + group_num)); +} + +template +class Runner{ +public: + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, TileCfg::TileParitionerGroupNum, TileCfg::TileParitionerM01>; + + using UniversalTraits = ck_tile::PersistentTileGemmUniversalTraits< + TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, + TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>; + + static constexpr ck_tile::GemmPipelineScheduler Scheduler = + 
ck_tile::GemmPipelineScheduler::Intrawave; + + using Problem = ck_tile::UniversalGemmPipelineProblem< + AType, BType, AccType, GemmShape, UniversalTraits, Scheduler>; + + using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + static constexpr ck_tile::memory_operation_enum MemOp = ck_tile::memory_operation_enum::set; + + using Epilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem< + AType, BType, ck_tile::tuple<>, AccType, + CType, ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, + Partitioner::MPerBlock, Partitioner::NPerBlock, + TileCfg::M_Warp, TileCfg::N_Warp, + TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, + Problem::TransposeC, MemOp>>; + + using Kernel = ck_tile::GroupedGemmKernel; + + void run(const ck_tile::stream_config& stream_cfg, + ck_tile::index_t group_num, + void* args_ptr, + uint32_t num_cu) { + launch_grouped_kernel(stream_cfg, group_num, args_ptr, num_cu); + } +}; + +// Arg builder kernel + +template +__global__ void build_args_kernel(ck_tile::GemmTransKernelArg<>* args, + const void* const* a_ptrs, + const void* const* b_ptrs, + void* const* d_ptrs, + const int64_t* ms, + const int64_t* ns, + const int64_t* ks, + ck_tile::index_t group_num, + ck_tile::index_t strideA, + ck_tile::index_t strideB, + ck_tile::index_t strideD, + ck_tile::index_t k_batch) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid >= group_num) + return; + + // CK_Tile's grouped arg uses arrays for As/Bs + const_cast&>(args[gid].group_karg.as_ptr)[0] = + static_cast(a_ptrs[gid]); + const_cast&>(args[gid].group_karg.bs_ptr)[0] = + static_cast(b_ptrs[gid]); + + args[gid].group_karg.e_ptr = d_ptrs[gid]; + + args[gid].group_karg.M = static_cast(ms[gid]); + args[gid].group_karg.N = static_cast(ns[gid]); + args[gid].group_karg.K = static_cast(ks[gid]); + + args[gid].group_karg.stride_As[0] = strideA; + args[gid].group_karg.stride_Bs[0] = strideB; + args[gid].group_karg.stride_E = strideD; + args[gid].group_karg.k_batch = k_batch; +} + +bool grouped_gemm_ck_tile(const transformer_engine::Tensor* const* A, + const transformer_engine::Tensor* const* B, + transformer_engine::Tensor* const* D, + int group_num, + bool transA, + bool transB, + void* workspace, + size_t workspace_bytes, + hipStream_t stream, + uint32_t num_cu_override = 0) { + // TE sometimes passes (A=weight, B=input, transA=1, transB=0) for y = x * W^T + // CK_Tile expects the left operand to be the activation matrix + // So for (transA && !transB), swap A/B and turn it into (!transA && transB) + const transformer_engine::Tensor* const* A_use = A; + const transformer_engine::Tensor* const* B_use = B; + bool transA_use = transA; + bool transB_use = transB; + if (transA && !transB) { + A_use = B; + B_use = A; + transA_use = false; + transB_use = true; + } + + if (!( (!transA_use && !transB_use) || (!transA_use && transB_use) )) { + NVTE_ERROR("grouped_gemm_ck_tile: only NN/NT/TN supported."); + return false; + } + + // DType routing: allow fp16/bf16 for now + const auto a_dtype = A_use[0]->dtype(); + const auto b_dtype = B_use[0]->dtype(); + const auto d_dtype = D[0]->dtype(); + if (a_dtype != b_dtype || a_dtype != d_dtype) { + NVTE_ERROR("grouped_gemm_ck_tile: dtype mismatch A/B/D."); + return false; + } + if (!(a_dtype == transformer_engine::DType::kFloat16 || + a_dtype == transformer_engine::DType::kBFloat16)) { + NVTE_ERROR("grouped_gemm_ck_tile: only fp16/bf16 supported."); + return false; + } + + // Workspace layout: + // [0] device arrays of pointers (A_ptrs, B_ptrs, 
D_ptrs) + // [1] device arrays of int64 (M, N, K) + // [2] ck_tile::GemmTransKernelArg<>[group_num] + const size_t ptr_arr_bytes = sizeof(void*) * static_cast(group_num); + const size_t i64_arr_bytes = sizeof(int64_t) * static_cast(group_num); + + const size_t off_a_ptrs = 0; + const size_t off_b_ptrs = off_a_ptrs + ptr_arr_bytes; + const size_t off_d_ptrs = off_b_ptrs + ptr_arr_bytes; + const size_t off_ms = off_d_ptrs + ptr_arr_bytes; + const size_t off_ns = off_ms + i64_arr_bytes; + const size_t off_ks = off_ns + i64_arr_bytes; + + const size_t off_args = ck_tile::integer_divide_ceil(off_ks + i64_arr_bytes, size_t(16)) * 16; + + const size_t args_bytes = sizeof(ck_tile::GemmTransKernelArg<>) * static_cast(group_num); + const size_t needed = off_args + args_bytes; + + if (workspace == nullptr || workspace_bytes < needed) { + NVTE_ERROR("grouped_gemm_ck_tile: insufficient workspace. Needed bytes=", needed); + return false; + } + + auto* base = static_cast(workspace); + + void** d_a_ptrs = reinterpret_cast(base + off_a_ptrs); + void** d_b_ptrs = reinterpret_cast(base + off_b_ptrs); + void** d_d_ptrs = reinterpret_cast(base + off_d_ptrs); + int64_t* d_ms = reinterpret_cast(base + off_ms); + int64_t* d_ns = reinterpret_cast(base + off_ns); + int64_t* d_ks = reinterpret_cast(base + off_ks); + + auto* d_args = reinterpret_cast*>(base + off_args); + + // Build host-side staging buffers and memcpy to device + std::vector h_a_ptrs(group_num); + std::vector h_b_ptrs(group_num); + std::vector h_d_ptrs(group_num); + std::vector h_ms(group_num); + std::vector h_ns(group_num); + std::vector h_ks(group_num); + + // Infer global N/K from group 0 + const auto& a0 = data_view(*A_use[0]); + const auto& b0 = data_view(*B_use[0]); + const auto& d0 = data_view(*D[0]); + if (a0.shape.size() != 2 || b0.shape.size() != 2 || d0.shape.size() != 2) { + NVTE_ERROR("grouped_gemm_ck_tile: expected 2D tensors."); + return false; + } + + printf("grouped_gemm_ck_tile gg0 A=[%zu,%zu] B=[%zu,%zu] D=[%zu,%zu] transA=%d transB=%d\n", + a0.shape[0], a0.shape[1], + b0.shape[0], b0.shape[1], + d0.shape[0], d0.shape[1], + (int)transA_use, (int)transB_use); + + // Infer logical M/K from A depending on transA + // - NN/NT: A stored [M,K] + // - TN: A stored [K,M] row-major, interpret as ColMajor [M,K] + const int64_t m0 = transA_use ? static_cast(a0.shape[1]) : static_cast(a0.shape[0]); + const int64_t k0 = transA_use ? static_cast(a0.shape[0]) : static_cast(a0.shape[1]); + + const int64_t n0 = transB_use ? static_cast(b0.shape[0]) + : static_cast(b0.shape[1]); + const int64_t kb = transB_use ? static_cast(b0.shape[1]) + : static_cast(b0.shape[0]); + if (kb != k0) { + NVTE_ERROR("grouped_gemm_ck_tile: K mismatch between A and B in group 0."); + return false; + } + if (static_cast(d0.shape[0]) != m0 || static_cast(d0.shape[1]) != n0) { + NVTE_ERROR("grouped_gemm_ck_tile: D shape mismatch in group 0."); + return false; + } + + for (int i = 0; i < group_num; ++i) { + const auto& ai = data_view(*A_use[i]); + const auto& bi = data_view(*B_use[i]); + const auto& di = data_view(*D[i]); + + if (ai.shape.size() != 2 || bi.shape.size() != 2 || di.shape.size() != 2) { + NVTE_ERROR("grouped_gemm_ck_tile: expected all groups to be 2D."); + return false; + } + + const int64_t mi = transA_use ? static_cast(ai.shape[1]) : static_cast(ai.shape[0]); + const int64_t ki = transA_use ? static_cast(ai.shape[0]) : static_cast(ai.shape[1]); + const int64_t ni = transB_use ? 
static_cast(bi.shape[0]) + : static_cast(bi.shape[1]); + const int64_t kbi = transB_use ? static_cast(bi.shape[1]) + : static_cast(bi.shape[0]); + + if (ki != k0 || ni != n0 || kbi != k0) { + NVTE_ERROR("grouped_gemm_ck_tile: N/K must be constant across groups."); + return false; + } + if (static_cast(di.shape[0]) != mi || static_cast(di.shape[1]) != n0) { + NVTE_ERROR("grouped_gemm_ck_tile: D shape mismatch in group ", i); + return false; + } + + h_a_ptrs[i] = ai.dptr; + h_b_ptrs[i] = bi.dptr; + h_d_ptrs[i] = di.dptr; + h_ms[i] = mi; + h_ns[i] = n0; + h_ks[i] = k0; + } + + HIP_CHECK_ERROR(hipMemcpyAsync(d_a_ptrs, h_a_ptrs.data(), ptr_arr_bytes, hipMemcpyHostToDevice, + reinterpret_cast(stream))); + HIP_CHECK_ERROR(hipMemcpyAsync(d_b_ptrs, h_b_ptrs.data(), ptr_arr_bytes, hipMemcpyHostToDevice, + reinterpret_cast(stream))); + HIP_CHECK_ERROR(hipMemcpyAsync(d_d_ptrs, h_d_ptrs.data(), ptr_arr_bytes, hipMemcpyHostToDevice, + reinterpret_cast(stream))); + HIP_CHECK_ERROR(hipMemcpyAsync(d_ms, h_ms.data(), i64_arr_bytes, hipMemcpyHostToDevice, + reinterpret_cast(stream))); + HIP_CHECK_ERROR(hipMemcpyAsync(d_ns, h_ns.data(), i64_arr_bytes, hipMemcpyHostToDevice, + reinterpret_cast(stream))); + HIP_CHECK_ERROR(hipMemcpyAsync(d_ks, h_ks.data(), i64_arr_bytes, hipMemcpyHostToDevice, + reinterpret_cast(stream))); + + // Leading dimensions for CK layouts: + // A is row-major [M,K] and we only support transA=false -> ALayout=RowMajor, strideA=K + // B is row-major [K,N] if NN -> BLayout=RowMajor, strideB=N + // B is row-major [N,K] if NT -> BLayout=ColMajor (logical [K,N]), strideB=K + const ck_tile::index_t strideA = static_cast(transA_use ? m0 : k0); + const ck_tile::index_t strideB = static_cast(transB_use ? k0 : n0); + const ck_tile::index_t strideD = static_cast(n0); + + // Build CK arg structs on device + { + const int threads = 256; + const int blocks = (group_num + threads - 1) / threads; + const ck_tile::index_t k_batch = 1; + if (a_dtype == transformer_engine::DType::kFloat16) { + using AType = TeDTypeToCk::type; + using BType = AType; + using CType = AType; + hipLaunchKernelGGL((build_args_kernel), + dim3(blocks), dim3(threads), 0, + reinterpret_cast(stream), + d_args, + const_cast(reinterpret_cast(d_a_ptrs)), + const_cast(reinterpret_cast(d_b_ptrs)), + reinterpret_cast(d_d_ptrs), + d_ms, d_ns, d_ks, + static_cast(group_num), + strideA, strideB, strideD, + k_batch); + } else { + using AType = TeDTypeToCk::type; + using BType = AType; + using CType = AType; + hipLaunchKernelGGL((build_args_kernel), + dim3(blocks), dim3(threads), 0, + reinterpret_cast(stream), + d_args, + const_cast(reinterpret_cast(d_a_ptrs)), + const_cast(reinterpret_cast(d_b_ptrs)), + reinterpret_cast(d_d_ptrs), + d_ms, d_ns, d_ks, + static_cast(group_num), + strideA, strideB, strideD, + k_batch); + } + } + + // Runner selection + const uint32_t num_cu = (num_cu_override != 0) ? 
num_cu_override + : static_cast(get_num_cu_for_stream(stream)); + const ck_tile::stream_config stream_cfg{reinterpret_cast(stream)}; + + // Choose layouts based on transB + if (a_dtype == transformer_engine::DType::kFloat16) { + using T = TeDTypeToCk::type; + + if (!transB_use) { + // NN: A RowMajor, B RowMajor, D RowMajor + Runner runner; + runner.run(stream_cfg, static_cast(group_num), d_args, num_cu); + } else { + // NT: B is stored as [N,K] row-major -> treat as ColMajor logical [K,N] + Runner runner; + runner.run(stream_cfg, static_cast(group_num), d_args, num_cu); + } + } else { + using T = TeDTypeToCk::type; + + if (!transB_use) { + Runner runner; + runner.run(stream_cfg, static_cast(group_num), d_args, num_cu); + } else { + Runner runner; + runner.run(stream_cfg, static_cast(group_num), d_args, num_cu); + } + } + + return true; +} + +bool grouped_gemm_ck_tile(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + int group_num, + bool transA, + bool transB, + NVTETensor* workspace, + hipStream_t stream) { + if (group_num <= 0) + return true; + + // Convert A/B/D arrays into TE Tensor* arrays + std::vector A_te(group_num); + std::vector B_te(group_num); + std::vector D_te(group_num); + + for (int i = 0; i < group_num; ++i) { + A_te[i] = transformer_engine::convertNVTETensorCheck(A[i]); + B_te[i] = transformer_engine::convertNVTETensorCheck(B[i]); + D_te[i] = transformer_engine::convertNVTETensorCheck(D[i]); + } + + // Workspace pointer + bytes + void* ws_ptr = nullptr; + size_t ws_bytes = 0; + if (workspace) { + auto* ws_te = transformer_engine::convertNVTETensorCheck(*workspace); + ws_ptr = ws_te->data.dptr; + ws_bytes = ws_te->data.numel() * transformer_engine::typeToSize(ws_te->data.dtype); + } + + return grouped_gemm_ck_tile(A_te.data(), B_te.data(), D_te.data(), + group_num, transA, transB, + ws_ptr, ws_bytes, + stream); +} diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 9c2ca9b4c..e583bc14f 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
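
Note on usage: the next hunk wires the CK-Tile path into nvte_multi_tensor_gemm behind the NVTE_USE_CK_GROUPED_GEMM environment variable, with NVTE_CK_GROUPED_GEMM_WARN_FALLBACK controlling fallback warnings. A minimal usage sketch on a ROCm build of Transformer Engine is shown below; it mirrors the gmm2.py benchmark in this series, but the sizes and m_splits are illustrative placeholders, not values taken from that script.

    import os
    os.environ["NVTE_USE_CK_GROUPED_GEMM"] = "1"            # opt into the CK-Tile grouped GEMM
    os.environ["NVTE_CK_GROUPED_GEMM_WARN_FALLBACK"] = "1"  # warn when falling back to hipBLASLt

    import torch
    import transformer_engine.pytorch as te

    E, K, N = 4, 512, 1024              # illustrative expert count and feature sizes
    m_splits = [128, 96, 160, 128]      # illustrative per-group row counts, sum = M_total
    x = torch.randn(sum(m_splits), K, device="cuda", dtype=torch.bfloat16)

    glinear = te.GroupedLinear(E, K, N, bias=False).to(device="cuda", dtype=torch.bfloat16)
    # fp16/bf16 inputs are routed to grouped_gemm_ck_tile when supported; otherwise hipBLASLt is used.
    y = glinear(x, m_splits=m_splits)
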
@@ -24,8 +24,11 @@ #include "../util/logging.h" #include "../util/multi_stream.h" #include "common/util/cuda_runtime.h" +#include "common/util/system.h" #ifndef __HIP_PLATFORM_AMD__ #include "cutlass_grouped_gemm.cuh" +#else +#include "ck_grouped_gemm.cuh" #endif #ifndef __HIP_PLATFORM_AMD__ @@ -788,7 +791,38 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor NVTE_API_CALL(nvte_multi_tensor_gemm); #ifdef __HIP_PLATFORM_AMD__ - multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, + const bool use_ck = transformer_engine::getenv("NVTE_USE_CK_GROUPED_GEMM", false); + const bool warn_fallback = + transformer_engine::getenv("NVTE_CK_GROUPED_GEMM_WARN_FALLBACK", false); + + auto is_supported_dtype = [&]() -> bool { + auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); + auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); + auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]); + auto A_dt = inputA->data.dtype; + auto B_dt = inputB->data.dtype; + auto D_dt = OutputD->data.dtype; + + return (A_dt == B_dt) && (A_dt == D_dt) && + (A_dt == transformer_engine::DType::kFloat16 || + A_dt == transformer_engine::DType::kBFloat16); + }; + + if (use_ck && + is_supported_dtype() && + !accumulate) { + + if (grouped_gemm_ck_tile(A, B, D, num_gemms, transa, transb, workspace, stream)) { + printf("grouped_gemm_ck_tile done.\n"); + return; + } else if (warn_fallback) { + NVTE_WARN("Fallback to hipBLASLt grouped GEMM (grouped_gemm_ck_tile returned false)."); + } + } + + NVTE_WARN("Fallback to hipBLASLt grouped GEMM (CK config unsupported).\n"); + + multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, workspace, accumulate, use_split_accumulator, math_sm_count, stream); #else const int current_device = transformer_engine::cuda::current_device(); From 9091e6ce73ea47b0436edc0af6f394465c2bd1cd Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 29 Jan 2026 10:52:22 -0600 Subject: [PATCH 12/13] restructure and enable tests --- gmm2.py | 90 ++- tests/pytorch/test_numerics.py | 16 +- .../common/gemm/ck_grouped_gemm.cuh | 549 ++++++++---------- .../common/gemm/cublaslt_gemm.cu | 13 +- 4 files changed, 319 insertions(+), 349 deletions(-) diff --git a/gmm2.py b/gmm2.py index e5bceebbb..016304d9f 100644 --- a/gmm2.py +++ b/gmm2.py @@ -1,13 +1,14 @@ -import os, torch +import os +import time +import torch import transformer_engine.pytorch as te -from time import time torch.manual_seed(0) os.environ["NVTE_USE_CK_GROUPED_GEMM"] = "1" os.environ["NVTE_CK_GROUPED_GEMM_WARN_FALLBACK"] = "1" -device = "cuda" +device = "cuda" dtype = torch.bfloat16 E = 4 @@ -18,45 +19,74 @@ x = torch.randn(M_total, K, device=device, dtype=dtype) -# TE -start = time() - +# Timing helper +def bench_cuda(fn, warmup=20, iters=100, name=""): + # Warmup + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + # Timed + start = time.time() + for _ in range(iters): + fn() + torch.cuda.synchronize() + end = time.time() + + avg_ms = (end - start) * 1000.0 / iters + if name: + print(f"{name}: {avg_ms:.3f} ms (avg over {iters} runs, {warmup} warmup)") + return avg_ms + +# TE GroupedLinear glinear = te.GroupedLinear(E, K, N, bias=False).to(device=device, dtype=dtype) -y_te = glinear(x, m_splits=m_splits) -print("TE time: ", time()-start) +def te_run(): + return glinear(x, m_splits=m_splits) + +te_ms = bench_cuda(te_run, warmup=20, iters=100, name="TE GroupedLinear") -Ws = [] -for e in range(E): - w = 
getattr(glinear, f"weight{e}") # expect [N, K] - Ws.append(w) -W = torch.stack(Ws, dim=0) # [E, N, K] +# Grab weights for reference path +Ws = [getattr(glinear, f"weight{e}") for e in range(E)] # each [N, K] +W = torch.stack(Ws, dim=0) # [E, N, K] assert W.shape == (E, N, K), f"Unexpected weight shape: {W.shape}" +# Torch reference (group loop) +offsets = [] +off = 0 +for m in m_splits: + offsets.append(off) + off += m -# Torch -start = time() +y_ref_buf = torch.empty((M_total, N), device=device, dtype=dtype) -ys = [] -offset = 0 -for e, m in enumerate(m_splits): - if m == 0: - continue - x_e = x[offset:offset+m] # [m, K] - y_e = x_e @ W[e].transpose(0, 1) # [m, N] - ys.append(y_e) - offset += m +def torch_run(): + # Fill the preallocated buffer + for e, m in enumerate(m_splits): + if m == 0: + continue + o = offsets[e] + y_ref_buf[o:o+m].copy_(x[o:o+m] @ W[e].transpose(0, 1)) + return y_ref_buf -y_ref = torch.cat(ys, dim=0) -print("Torch time:", time()-start) +torch_ms = bench_cuda(torch_run, warmup=20, iters=100, name="Torch loop (prealloc out)") + +# Compare outputs +y_te = te_run() +y_ref = torch_run().clone() -# Compare diff = (y_te.float() - y_ref.float()) max_abs = diff.abs().max().item() rel = (diff.abs() / (y_ref.float().abs() + 1e-6)).max().item() -print(f"{y_te.shape=}, {y_ref.shape=}") -print("max_abs_err:", max_abs) -print("max_rel_err:", rel) +print(f"\nErrors:") +print(f" {y_te.shape=}, {y_ref.shape=}") +print(" max_abs_err:", max_abs) +print(" max_rel_err:", rel) torch.testing.assert_close(y_te.float(), y_ref.float(), rtol=3e-2, atol=3e-2) + +print(f"\nTiming:") +print(f" TE avg: {te_ms:.3f} ms") +print(f" Torch avg: {torch_ms:.3f} ms") +print(f" Speedup: {torch_ms/te_ms:.2f}x (Torch / TE)") diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a4dfd64ba..5f1489f88 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1385,7 +1385,7 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_ if IS_HIP_EXTENSION: if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias: - pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.") + pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") te_linear_ref = Linear( config.hidden_size, @@ -1677,7 +1677,7 @@ def test_layernorm_linear_accuracy_delay_wgrad_compute( ): if IS_HIP_EXTENSION: if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias: - pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.") + pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") config = model_configs[model] ln_linear_ref = LayerNormLinear( @@ -1891,7 +1891,7 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute( if IS_HIP_EXTENSION: if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias: - pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.") + pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") ln_mlp = LayerNormMLP( hidden_size=config.hidden_size, @@ -2036,7 +2036,7 @@ def test_grouped_linear_accuracy( if IS_HIP_EXTENSION: if dtype not in (torch.float32,) and fuse_wgrad_accumulation and not fp8: - pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.") + pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") @@ -2115,7 +2115,7 @@ def 
test_grouped_linear_accuracy( @pytest.mark.skipif( - torch.cuda.get_device_capability() != (9, 0), + torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION, reason="Only enable CUTLASS grouped gemm on Hopper", ) @pytest.mark.parametrize("dtype", param_types, ids=str) @@ -2133,6 +2133,9 @@ def test_grouped_linear_accuracy_cutlass( delay_wgrad_compute, ): os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + if IS_HIP_EXTENSION: + os.environ["NVTE_USE_CK_GROUPED_GEMM"] = "1" + os.environ["NVTE_CK_GROUPED_GEMM_WARN_FALLBACK"] = "1" test_grouped_linear_accuracy( dtype, num_gemms, @@ -2147,6 +2150,9 @@ def test_grouped_linear_accuracy_cutlass( use_cutlass=True, ) os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) + if IS_HIP_EXTENSION: + os.environ.pop("NVTE_USE_CK_GROUPED_GEMM", None) + os.environ.pop("NVTE_CK_GROUPED_GEMM_WARN_FALLBACK", None) @pytest.mark.parametrize("dtype", param_types, ids=str) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cuh b/transformer_engine/common/gemm/ck_grouped_gemm.cuh index fa1f1cca1..1171e33f9 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cuh +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cuh @@ -1,45 +1,23 @@ +/* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. */ + #include #include "ck_tile/core.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -static inline int get_num_cu_for_stream(hipStream_t stream) { - int device = -1; - hipError_t st = hipGetDevice(&device); - if (st != hipSuccess) - return 0; - - hipDeviceProp_t prop{}; - st = hipGetDeviceProperties(&prop, device); - if (st != hipSuccess) - return 0; - - return prop.multiProcessorCount; -} +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -// Map TE DType to CK_Tile scalar type template struct TeDTypeToCk; +template <> struct TeDTypeToCk { using type = ck_tile::half_t; }; +template <> struct TeDTypeToCk{ using type = ck_tile::bfloat16_t; }; -template <> struct TeDTypeToCk { - using type = ck_tile::half_t; -}; -template <> struct TeDTypeToCk { - using type = ck_tile::bfloat16_t; -}; - -// TE Tensor -> SimpleTensor view static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { - // For GEMM we want the "data" view (rowwise) - return t.data; + return t.data; // rowwise data view } -// CK_Tile runner - -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; - struct TileCfg_basic { static constexpr ck_tile::index_t M_Tile = 256; static constexpr ck_tile::index_t N_Tile = 128; @@ -59,31 +37,14 @@ struct TileCfg_basic { static constexpr bool DoubleSmemBuffer = false; - // Spatially-local partitioner parameters - static constexpr ck_tile::index_t TileParitionerGroupNum = 8; - static constexpr ck_tile::index_t TileParitionerM01 = 1; + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 1; }; -template -inline void launch_grouped_kernel(const ck_tile::stream_config& stream_cfg, - ck_tile::index_t group_num, - void* args_ptr, - uint32_t num_cu) { - constexpr int kBlockPerCu = 1; - const dim3 blocks = Kernel::BlockSize(); - dim3 grids = Kernel::MaxOccupancyGridSize(stream_cfg); - grids.x = std::min(grids.x, static_cast(num_cu)); - ck_tile::launch_kernel( - stream_cfg, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, - 
ck_tile::cast_pointer_to_constant_address_space(args_ptr), - group_num)); -} - template + typename TileCfg, ck_tile::memory_operation_enum MemOp, + typename AccType = float> class Runner{ public: using GemmShape = ck_tile::TileGemmShape< @@ -92,7 +53,7 @@ public: ck_tile::sequence>; using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< - GemmShape, TileCfg::TileParitionerGroupNum, TileCfg::TileParitionerM01>; + GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; using UniversalTraits = ck_tile::PersistentTileGemmUniversalTraits< TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, @@ -106,8 +67,6 @@ public: using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; - static constexpr ck_tile::memory_operation_enum MemOp = ck_tile::memory_operation_enum::set; - using Epilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem< AType, BType, ck_tile::tuple<>, AccType, @@ -119,296 +78,272 @@ public: Problem::TransposeC, MemOp>>; using Kernel = ck_tile::GroupedGemmKernel; - - void run(const ck_tile::stream_config& stream_cfg, - ck_tile::index_t group_num, - void* args_ptr, - uint32_t num_cu) { - launch_grouped_kernel(stream_cfg, group_num, args_ptr, num_cu); - } }; -// Arg builder kernel - -template -__global__ void build_args_kernel(ck_tile::GemmTransKernelArg<>* args, - const void* const* a_ptrs, - const void* const* b_ptrs, - void* const* d_ptrs, - const int64_t* ms, - const int64_t* ns, - const int64_t* ks, - ck_tile::index_t group_num, - ck_tile::index_t strideA, - ck_tile::index_t strideB, - ck_tile::index_t strideD, - ck_tile::index_t k_batch) { - const int gid = blockIdx.x * blockDim.x + threadIdx.x; - if (gid >= group_num) - return; - - // CK_Tile's grouped arg uses arrays for As/Bs - const_cast&>(args[gid].group_karg.as_ptr)[0] = - static_cast(a_ptrs[gid]); - const_cast&>(args[gid].group_karg.bs_ptr)[0] = - static_cast(b_ptrs[gid]); - - args[gid].group_karg.e_ptr = d_ptrs[gid]; - - args[gid].group_karg.M = static_cast(ms[gid]); - args[gid].group_karg.N = static_cast(ns[gid]); - args[gid].group_karg.K = static_cast(ks[gid]); - - args[gid].group_karg.stride_As[0] = strideA; - args[gid].group_karg.stride_Bs[0] = strideB; - args[gid].group_karg.stride_E = strideD; - args[gid].group_karg.k_batch = k_batch; -} +template +static inline void launch_tileloop_kernel(const ck_tile::stream_config& s, + ck_tile::index_t group_num, + void* kargs_dev) +{ + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); -bool grouped_gemm_ck_tile(const transformer_engine::Tensor* const* A, - const transformer_engine::Tensor* const* B, - transformer_engine::Tensor* const* D, - int group_num, - bool transA, - bool transB, - void* workspace, - size_t workspace_bytes, - hipStream_t stream, - uint32_t num_cu_override = 0) { - // TE sometimes passes (A=weight, B=input, transA=1, transB=0) for y = x * W^T - // CK_Tile expects the left operand to be the activation matrix - // So for (transA && !transB), swap A/B and turn it into (!transA && transB) - const transformer_engine::Tensor* const* A_use = A; - const transformer_engine::Tensor* const* B_use = B; - bool transA_use = transA; - bool transB_use = transB; - if (transA && !transB) { - A_use = B; - B_use = A; - transA_use = false; - transB_use = true; - } + ck_tile::launch_kernel( + s, + ck_tile::make_kernel<1>( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_dev), + group_num)); +} - if (!( (!transA_use && !transB_use) || (!transA_use && transB_use) )) { 
- NVTE_ERROR("grouped_gemm_ck_tile: only NN/NT/TN supported."); +template +static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, + const transformer_engine::Tensor* const* B_use, + transformer_engine::Tensor* const* D, + int group_num, + bool transA_use, + bool transB_use, + void* workspace, + size_t workspace_bytes, + hipStream_t stream) +{ + using R = Runner; + using Kernel = typename R::Kernel; + + const size_t needed = Kernel::GetWorkSpaceSize(group_num); + if (!workspace || workspace_bytes < needed) { + NVTE_ERROR("grouped_gemm_ck_tile: insufficient workspace. Needed bytes=", needed); return false; } - // DType routing: allow fp16/bf16 for now - const auto a_dtype = A_use[0]->dtype(); - const auto b_dtype = B_use[0]->dtype(); - const auto d_dtype = D[0]->dtype(); - if (a_dtype != b_dtype || a_dtype != d_dtype) { - NVTE_ERROR("grouped_gemm_ck_tile: dtype mismatch A/B/D."); - return false; - } - if (!(a_dtype == transformer_engine::DType::kFloat16 || - a_dtype == transformer_engine::DType::kBFloat16)) { - NVTE_ERROR("grouped_gemm_ck_tile: only fp16/bf16 supported."); - return false; - } + std::vector> descs; + descs.reserve(group_num); + + for (int i = 0; i < group_num; ++i) { + const auto& a = data_view(*A_use[i]); + const auto& b = data_view(*B_use[i]); + const auto& d = data_view(*D[i]); + + if (a.shape.size() != 2 || b.shape.size() != 2 || d.shape.size() != 2) { + NVTE_ERROR("grouped_gemm_ck_tile: expected all groups to be 2D."); + return false; + } - // Workspace layout: - // [0] device arrays of pointers (A_ptrs, B_ptrs, D_ptrs) - // [1] device arrays of int64 (M, N, K) - // [2] ck_tile::GemmTransKernelArg<>[group_num] - const size_t ptr_arr_bytes = sizeof(void*) * static_cast(group_num); - const size_t i64_arr_bytes = sizeof(int64_t) * static_cast(group_num); + const int64_t Ad0 = a.shape[0]; + const int64_t Ad1 = a.shape[1]; + const int64_t Bd0 = b.shape[0]; + const int64_t Bd1 = b.shape[1]; - const size_t off_a_ptrs = 0; - const size_t off_b_ptrs = off_a_ptrs + ptr_arr_bytes; - const size_t off_d_ptrs = off_b_ptrs + ptr_arr_bytes; - const size_t off_ms = off_d_ptrs + ptr_arr_bytes; - const size_t off_ns = off_ms + i64_arr_bytes; - const size_t off_ks = off_ns + i64_arr_bytes; + const int64_t M = transA_use ? Ad1 : Ad0; + const int64_t K = transA_use ? Ad0 : Ad1; + const int64_t N = transB_use ? Bd0 : Bd1; + const int64_t Kb = transB_use ? Bd1 : Bd0; - const size_t off_args = ck_tile::integer_divide_ceil(off_ks + i64_arr_bytes, size_t(16)) * 16; + if (Kb != K) { + NVTE_ERROR("grouped_gemm_ck_tile: K mismatch between A and B in group ", i); + return false; + } - const size_t args_bytes = sizeof(ck_tile::GemmTransKernelArg<>) * static_cast(group_num); - const size_t needed = off_args + args_bytes; + if (d.shape[0] != M || d.shape[1] != N) { + NVTE_ERROR("grouped_gemm_ck_tile: D shape mismatch in group ", i); + return false; + } - if (workspace == nullptr || workspace_bytes < needed) { - NVTE_ERROR("grouped_gemm_ck_tile: insufficient workspace. 
Needed bytes=", needed); + const ck_tile::index_t stride_A = a.shape[1]; + const ck_tile::index_t stride_B = b.shape[1]; + const ck_tile::index_t stride_E = d.shape[1]; + + descs.emplace_back( + a.dptr, + b.dptr, + std::array{}, + d.dptr, + 1, + M, + N, + K, + stride_A, + stride_B, + std::array{}, + stride_E); + } + + auto kargs = Kernel::MakeKargs(descs); + if (!Kernel::IsSupportedArgument(kargs)) { + NVTE_ERROR("grouped_gemm_ck_tile: CK-Tile kernel arguments not supported for this config."); return false; } - auto* base = static_cast(workspace); + HIP_CHECK_ERROR(hipMemcpyAsync(workspace, + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + hipMemcpyHostToDevice, + stream)); - void** d_a_ptrs = reinterpret_cast(base + off_a_ptrs); - void** d_b_ptrs = reinterpret_cast(base + off_b_ptrs); - void** d_d_ptrs = reinterpret_cast(base + off_d_ptrs); - int64_t* d_ms = reinterpret_cast(base + off_ms); - int64_t* d_ns = reinterpret_cast(base + off_ns); - int64_t* d_ks = reinterpret_cast(base + off_ks); + const ck_tile::stream_config s{stream}; + launch_tileloop_kernel(s, group_num, workspace); + return true; +} - auto* d_args = reinterpret_cast*>(base + off_args); +static inline bool infer_gemm_mode_group0(const transformer_engine::Tensor* const* A, + const transformer_engine::Tensor* const* B, + transformer_engine::Tensor* const* D, + int group_num, + const transformer_engine::Tensor* const*& A_use, + const transformer_engine::Tensor* const*& B_use, + bool& transA_use, + bool& transB_use) +{ + A_use = A; + B_use = B; + transA_use = false; + transB_use = false; - // Build host-side staging buffers and memcpy to device - std::vector h_a_ptrs(group_num); - std::vector h_b_ptrs(group_num); - std::vector h_d_ptrs(group_num); - std::vector h_ms(group_num); - std::vector h_ns(group_num); - std::vector h_ks(group_num); + if (group_num <= 0) + return true; - // Infer global N/K from group 0 - const auto& a0 = data_view(*A_use[0]); - const auto& b0 = data_view(*B_use[0]); + const auto& a0 = data_view(*A[0]); + const auto& b0 = data_view(*B[0]); const auto& d0 = data_view(*D[0]); - if (a0.shape.size() != 2 || b0.shape.size() != 2 || d0.shape.size() != 2) { - NVTE_ERROR("grouped_gemm_ck_tile: expected 2D tensors."); - return false; - } - printf("grouped_gemm_ck_tile gg0 A=[%zu,%zu] B=[%zu,%zu] D=[%zu,%zu] transA=%d transB=%d\n", - a0.shape[0], a0.shape[1], - b0.shape[0], b0.shape[1], - d0.shape[0], d0.shape[1], - (int)transA_use, (int)transB_use); - - // Infer logical M/K from A depending on transA - // - NN/NT: A stored [M,K] - // - TN: A stored [K,M] row-major, interpret as ColMajor [M,K] - const int64_t m0 = transA_use ? static_cast(a0.shape[1]) : static_cast(a0.shape[0]); - const int64_t k0 = transA_use ? static_cast(a0.shape[0]) : static_cast(a0.shape[1]); - - const int64_t n0 = transB_use ? static_cast(b0.shape[0]) - : static_cast(b0.shape[1]); - const int64_t kb = transB_use ? 
static_cast(b0.shape[1]) - : static_cast(b0.shape[0]); - if (kb != k0) { - NVTE_ERROR("grouped_gemm_ck_tile: K mismatch between A and B in group 0."); - return false; - } - if (static_cast(d0.shape[0]) != m0 || static_cast(d0.shape[1]) != n0) { - NVTE_ERROR("grouped_gemm_ck_tile: D shape mismatch in group 0."); + if (a0.shape.size() != 2 || b0.shape.size() != 2 || d0.shape.size() != 2) { return false; } - for (int i = 0; i < group_num; ++i) { - const auto& ai = data_view(*A_use[i]); - const auto& bi = data_view(*B_use[i]); - const auto& di = data_view(*D[i]); - - if (ai.shape.size() != 2 || bi.shape.size() != 2 || di.shape.size() != 2) { - NVTE_ERROR("grouped_gemm_ck_tile: expected all groups to be 2D."); - return false; + const int64_t Ad0 = a0.shape[0]; + const int64_t Ad1 = a0.shape[1]; + const int64_t Bd0 = b0.shape[0]; + const int64_t Bd1 = b0.shape[1]; + const int64_t Dm = d0.shape[0]; + const int64_t Dn = d0.shape[1]; + + auto check = [&](bool do_swap, bool ta, bool tb) -> bool { + const int64_t A0d0 = do_swap ? Bd0 : Ad0; + const int64_t A0d1 = do_swap ? Bd1 : Ad1; + const int64_t B0d0 = do_swap ? Ad0 : Bd0; + const int64_t B0d1 = do_swap ? Ad1 : Bd1; + + const int64_t M = ta ? A0d1 : A0d0; + const int64_t K = ta ? A0d0 : A0d1; + const int64_t N = tb ? B0d0 : B0d1; + const int64_t Kb = tb ? B0d1 : B0d0; + + return (M == Dm) && (N == Dn) && (K == Kb); + }; + + // Try all candidates; prefer "no swap" first, then swap. + for (bool do_swap : {false, true}) { + for (bool ta : {false, true}) { + for (bool tb : {false, true}) { + if (check(do_swap, ta, tb)) { + A_use = do_swap ? B : A; + B_use = do_swap ? A : B; + transA_use = ta; + transB_use = tb; + return true; + } + } } + } - const int64_t mi = transA_use ? static_cast(ai.shape[1]) : static_cast(ai.shape[0]); - const int64_t ki = transA_use ? static_cast(ai.shape[0]) : static_cast(ai.shape[1]); - const int64_t ni = transB_use ? static_cast(bi.shape[0]) - : static_cast(bi.shape[1]); - const int64_t kbi = transB_use ? static_cast(bi.shape[1]) - : static_cast(bi.shape[0]); + // Nothing matched D = op(A) * op(B) + return false; +} - if (ki != k0 || ni != n0 || kbi != k0) { - NVTE_ERROR("grouped_gemm_ck_tile: N/K must be constant across groups."); - return false; - } - if (static_cast(di.shape[0]) != mi || static_cast(di.shape[1]) != n0) { - NVTE_ERROR("grouped_gemm_ck_tile: D shape mismatch in group ", i); - return false; - } +bool grouped_gemm_ck_tile(const transformer_engine::Tensor* const* A, + const transformer_engine::Tensor* const* B, + transformer_engine::Tensor* const* D, + int group_num, + bool transA, + bool transB, + void* workspace, + size_t workspace_bytes, + bool accumulate, + hipStream_t stream) +{ + const transformer_engine::Tensor* const* A_use = A; + const transformer_engine::Tensor* const* B_use = B; + bool transA_use = transA; + bool transB_use = transB; - h_a_ptrs[i] = ai.dptr; - h_b_ptrs[i] = bi.dptr; - h_d_ptrs[i] = di.dptr; - h_ms[i] = mi; - h_ns[i] = n0; - h_ks[i] = k0; + // If TE's flags disagree with storage, infer the correct mode from shapes. + if (!infer_gemm_mode_group0(A, B, D, group_num, A_use, B_use, transA_use, transB_use)) { + const auto& a0 = data_view(*A[0]); + const auto& b0 = data_view(*B[0]); + const auto& d0 = data_view(*D[0]); + NVTE_ERROR("grouped_gemm_ck_tile: could not infer a consistent GEMM mode from shapes. 
", + "A0=[", a0.shape[0], ",", a0.shape[1], "] ", + "B0=[", b0.shape[0], ",", b0.shape[1], "] ", + "D0=[", d0.shape[0], ",", d0.shape[1], "] ", + "given flags transA=", transA, " transB=", transB); + return false; } - HIP_CHECK_ERROR(hipMemcpyAsync(d_a_ptrs, h_a_ptrs.data(), ptr_arr_bytes, hipMemcpyHostToDevice, - reinterpret_cast(stream))); - HIP_CHECK_ERROR(hipMemcpyAsync(d_b_ptrs, h_b_ptrs.data(), ptr_arr_bytes, hipMemcpyHostToDevice, - reinterpret_cast(stream))); - HIP_CHECK_ERROR(hipMemcpyAsync(d_d_ptrs, h_d_ptrs.data(), ptr_arr_bytes, hipMemcpyHostToDevice, - reinterpret_cast(stream))); - HIP_CHECK_ERROR(hipMemcpyAsync(d_ms, h_ms.data(), i64_arr_bytes, hipMemcpyHostToDevice, - reinterpret_cast(stream))); - HIP_CHECK_ERROR(hipMemcpyAsync(d_ns, h_ns.data(), i64_arr_bytes, hipMemcpyHostToDevice, - reinterpret_cast(stream))); - HIP_CHECK_ERROR(hipMemcpyAsync(d_ks, h_ks.data(), i64_arr_bytes, hipMemcpyHostToDevice, - reinterpret_cast(stream))); - - // Leading dimensions for CK layouts: - // A is row-major [M,K] and we only support transA=false -> ALayout=RowMajor, strideA=K - // B is row-major [K,N] if NN -> BLayout=RowMajor, strideB=N - // B is row-major [N,K] if NT -> BLayout=ColMajor (logical [K,N]), strideB=K - const ck_tile::index_t strideA = static_cast(transA_use ? m0 : k0); - const ck_tile::index_t strideB = static_cast(transB_use ? k0 : n0); - const ck_tile::index_t strideD = static_cast(n0); - - // Build CK arg structs on device - { - const int threads = 256; - const int blocks = (group_num + threads - 1) / threads; - const ck_tile::index_t k_batch = 1; - if (a_dtype == transformer_engine::DType::kFloat16) { - using AType = TeDTypeToCk::type; - using BType = AType; - using CType = AType; - hipLaunchKernelGGL((build_args_kernel), - dim3(blocks), dim3(threads), 0, - reinterpret_cast(stream), - d_args, - const_cast(reinterpret_cast(d_a_ptrs)), - const_cast(reinterpret_cast(d_b_ptrs)), - reinterpret_cast(d_d_ptrs), - d_ms, d_ns, d_ks, - static_cast(group_num), - strideA, strideB, strideD, - k_batch); - } else { - using AType = TeDTypeToCk::type; - using BType = AType; - using CType = AType; - hipLaunchKernelGGL((build_args_kernel), - dim3(blocks), dim3(threads), 0, - reinterpret_cast(stream), - d_args, - const_cast(reinterpret_cast(d_a_ptrs)), - const_cast(reinterpret_cast(d_b_ptrs)), - reinterpret_cast(d_d_ptrs), - d_ms, d_ns, d_ks, - static_cast(group_num), - strideA, strideB, strideD, - k_batch); - } - } + const auto a_dtype = A_use[0]->dtype(); - // Runner selection - const uint32_t num_cu = (num_cu_override != 0) ? num_cu_override - : static_cast(get_num_cu_for_stream(stream)); - const ck_tile::stream_config stream_cfg{reinterpret_cast(stream)}; + const auto memop = accumulate ? ck_tile::memory_operation_enum::atomic_add + : ck_tile::memory_operation_enum::set; - // Choose layouts based on transB if (a_dtype == transformer_engine::DType::kFloat16) { using T = TeDTypeToCk::type; - if (!transB_use) { - // NN: A RowMajor, B RowMajor, D RowMajor - Runner runner; - runner.run(stream_cfg, static_cast(group_num), d_args, num_cu); - } else { - // NT: B is stored as [N,K] row-major -> treat as ColMajor logical [K,N] - Runner runner; - runner.run(stream_cfg, static_cast(group_num), d_args, num_cu); - } + if (!transA_use && !transB_use) + return (memop == ck_tile::memory_operation_enum::set) + ? 
run_grouped_impl( + A_use, B_use, D, group_num, false, false, workspace, workspace_bytes, stream) + : run_grouped_impl( + A_use, B_use, D, group_num, false, false, workspace, workspace_bytes, stream); + + if (!transA_use && transB_use) + return (memop == ck_tile::memory_operation_enum::set) + ? run_grouped_impl( + A_use, B_use, D, group_num, false, true, workspace, workspace_bytes, stream) + : run_grouped_impl( + A_use, B_use, D, group_num, false, true, workspace, workspace_bytes, stream); + + if (transA_use && !transB_use) + return (memop == ck_tile::memory_operation_enum::set) + ? run_grouped_impl( + A_use, B_use, D, group_num, true, false, workspace, workspace_bytes, stream) + : run_grouped_impl( + A_use, B_use, D, group_num, true, false, workspace, workspace_bytes, stream); + + return (memop == ck_tile::memory_operation_enum::set) + ? run_grouped_impl( + A_use, B_use, D, group_num, true, true, workspace, workspace_bytes, stream) + : run_grouped_impl( + A_use, B_use, D, group_num, true, true, workspace, workspace_bytes, stream); } else { using T = TeDTypeToCk::type; - if (!transB_use) { - Runner runner; - runner.run(stream_cfg, static_cast(group_num), d_args, num_cu); - } else { - Runner runner; - runner.run(stream_cfg, static_cast(group_num), d_args, num_cu); - } + if (!transA_use && !transB_use) + return (memop == ck_tile::memory_operation_enum::set) + ? run_grouped_impl( + A_use, B_use, D, group_num, false, false, workspace, workspace_bytes, stream) + : run_grouped_impl( + A_use, B_use, D, group_num, false, false, workspace, workspace_bytes, stream); + + if (!transA_use && transB_use) + return (memop == ck_tile::memory_operation_enum::set) + ? run_grouped_impl( + A_use, B_use, D, group_num, false, true, workspace, workspace_bytes, stream) + : run_grouped_impl( + A_use, B_use, D, group_num, false, true, workspace, workspace_bytes, stream); + + if (transA_use && !transB_use) + return (memop == ck_tile::memory_operation_enum::set) + ? run_grouped_impl( + A_use, B_use, D, group_num, true, false, workspace, workspace_bytes, stream) + : run_grouped_impl( + A_use, B_use, D, group_num, true, false, workspace, workspace_bytes, stream); + + return (memop == ck_tile::memory_operation_enum::set) + ? 
run_grouped_impl( + A_use, B_use, D, group_num, true, true, workspace, workspace_bytes, stream) + : run_grouped_impl( + A_use, B_use, D, group_num, true, true, workspace, workspace_bytes, stream); } - - return true; } bool grouped_gemm_ck_tile(const NVTETensor* A, @@ -418,7 +353,9 @@ bool grouped_gemm_ck_tile(const NVTETensor* A, bool transA, bool transB, NVTETensor* workspace, - hipStream_t stream) { + bool accumulate, + hipStream_t stream) +{ if (group_num <= 0) return true; @@ -444,6 +381,6 @@ bool grouped_gemm_ck_tile(const NVTETensor* A, return grouped_gemm_ck_tile(A_te.data(), B_te.data(), D_te.data(), group_num, transA, transB, - ws_ptr, ws_bytes, + ws_ptr, ws_bytes, accumulate, stream); } diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index e583bc14f..fcbdac91c 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -808,20 +808,17 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor A_dt == transformer_engine::DType::kBFloat16); }; - if (use_ck && - is_supported_dtype() && - !accumulate) { - - if (grouped_gemm_ck_tile(A, B, D, num_gemms, transa, transb, workspace, stream)) { - printf("grouped_gemm_ck_tile done.\n"); + if (use_ck && is_supported_dtype()) { + if (grouped_gemm_ck_tile(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream)) { + // NVTE_WARN("grouped_gemm_ck_tile done.\n"); return; } else if (warn_fallback) { NVTE_WARN("Fallback to hipBLASLt grouped GEMM (grouped_gemm_ck_tile returned false)."); } + } else if (warn_fallback) { + NVTE_WARN("Fallback to hipBLASLt grouped GEMM (CK config unsupported or CK disabled). use_ck=", use_ck, " is_supported_dtype=", is_supported_dtype()); } - NVTE_WARN("Fallback to hipBLASLt grouped GEMM (CK config unsupported).\n"); - multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, workspace, accumulate, use_split_accumulator, math_sm_count, stream); #else From 4e9ead9a5a8de6266a44e873ef01d6bf6a147e61 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 30 Jan 2026 14:47:30 -0600 Subject: [PATCH 13/13] grid improvements --- gmm2.py | 10 ++++------ transformer_engine/common/gemm/ck_grouped_gemm.cuh | 5 +++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/gmm2.py b/gmm2.py index 016304d9f..8966afa02 100644 --- a/gmm2.py +++ b/gmm2.py @@ -20,7 +20,7 @@ x = torch.randn(M_total, K, device=device, dtype=dtype) # Timing helper -def bench_cuda(fn, warmup=20, iters=100, name=""): +def bench_cuda(fn, warmup=20, iters=100): # Warmup for _ in range(warmup): fn() @@ -34,8 +34,6 @@ def bench_cuda(fn, warmup=20, iters=100, name=""): end = time.time() avg_ms = (end - start) * 1000.0 / iters - if name: - print(f"{name}: {avg_ms:.3f} ms (avg over {iters} runs, {warmup} warmup)") return avg_ms # TE GroupedLinear @@ -44,7 +42,7 @@ def bench_cuda(fn, warmup=20, iters=100, name=""): def te_run(): return glinear(x, m_splits=m_splits) -te_ms = bench_cuda(te_run, warmup=20, iters=100, name="TE GroupedLinear") +te_ms = bench_cuda(te_run, warmup=20, iters=100) # Grab weights for reference path Ws = [getattr(glinear, f"weight{e}") for e in range(E)] # each [N, K] @@ -69,7 +67,7 @@ def torch_run(): y_ref_buf[o:o+m].copy_(x[o:o+m] @ W[e].transpose(0, 1)) return y_ref_buf -torch_ms = bench_cuda(torch_run, warmup=20, iters=100, name="Torch loop (prealloc out)") +torch_ms = bench_cuda(torch_run, warmup=20, iters=100) # Compare outputs y_te = te_run() @@ 
-79,7 +77,7 @@ def torch_run(): max_abs = diff.abs().max().item() rel = (diff.abs() / (y_ref.float().abs() + 1e-6)).max().item() -print(f"\nErrors:") +print(f"Errors:") print(f" {y_te.shape=}, {y_ref.shape=}") print(" max_abs_err:", max_abs) print(" max_rel_err:", rel) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cuh b/transformer_engine/common/gemm/ck_grouped_gemm.cuh index 1171e33f9..2ae402c47 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cuh +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cuh @@ -82,11 +82,11 @@ public: template static inline void launch_tileloop_kernel(const ck_tile::stream_config& s, + dim3 grids, ck_tile::index_t group_num, void* kargs_dev) { const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); ck_tile::launch_kernel( s, @@ -169,6 +169,7 @@ static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, stride_E); } + const dim3 grids = Kernel::GridSize(descs); auto kargs = Kernel::MakeKargs(descs); if (!Kernel::IsSupportedArgument(kargs)) { NVTE_ERROR("grouped_gemm_ck_tile: CK-Tile kernel arguments not supported for this config."); @@ -182,7 +183,7 @@ static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, stream)); const ck_tile::stream_config s{stream}; - launch_tileloop_kernel(s, group_num, workspace); + launch_tileloop_kernel(s, grids, group_num, workspace); return true; }
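
Closing illustration of the operand handling above: infer_gemm_mode_group0() (added in the restructure patch) ignores the caller's transA/transB flags and searches for a (swap, transA, transB) combination whose derived M/N/K are consistent with the stored shapes of group 0. A small Python mock of that search follows; the function name and the concrete shapes are illustrative only and are not part of the patch.

    from itertools import product

    def infer_mode(a_shape, b_shape, d_shape):
        """Find (swap, transA, transB) such that D = op(A') * op(B') is shape-consistent."""
        # "no swap" candidates are tried first, matching the C++ loop order.
        for swap, ta, tb in product((False, True), repeat=3):
            a = b_shape if swap else a_shape
            b = a_shape if swap else b_shape
            M, K  = (a[1], a[0]) if ta else (a[0], a[1])
            N, Kb = (b[0], b[1]) if tb else (b[1], b[0])
            if (M, N) == (d_shape[0], d_shape[1]) and K == Kb:
                return swap, ta, tb
        return None  # no consistent mode -> the C++ code reports an error and falls back

    # GroupedLinear forward: A = weight [N, K], B = input [M, K], D = output [M, N],
    # passed with transA=1, transB=0. The search resolves this to "swap A/B + NT",
    # matching the operand-swap comment in the original CK frontend:
    print(infer_mode((1024, 512), (256, 512), (256, 1024)))  # -> (True, False, True)
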