diff --git a/gmm2.py b/gmm2.py
new file mode 100644
index 000000000..8966afa02
--- /dev/null
+++ b/gmm2.py
@@ -0,0 +1,90 @@
+import os
+import time
+import torch
+import transformer_engine.pytorch as te
+
+torch.manual_seed(0)
+
+os.environ["NVTE_USE_CK_GROUPED_GEMM"] = "1"
+os.environ["NVTE_CK_GROUPED_GEMM_WARN_FALLBACK"] = "1"
+
+device = "cuda"
+dtype = torch.bfloat16
+
+E = 4
+K = 1024
+N = 2048
+m_splits = [128, 64, 0, 256]
+M_total = sum(m_splits)
+
+x = torch.randn(M_total, K, device=device, dtype=dtype)
+
+# Timing helper
+def bench_cuda(fn, warmup=20, iters=100):
+    # Warmup
+    for _ in range(warmup):
+        fn()
+    torch.cuda.synchronize()
+
+    # Timed
+    start = time.time()
+    for _ in range(iters):
+        fn()
+    torch.cuda.synchronize()
+    end = time.time()
+
+    avg_ms = (end - start) * 1000.0 / iters
+    return avg_ms
+
+# TE GroupedLinear
+glinear = te.GroupedLinear(E, K, N, bias=False).to(device=device, dtype=dtype)
+
+def te_run():
+    return glinear(x, m_splits=m_splits)
+
+te_ms = bench_cuda(te_run, warmup=20, iters=100)
+
+# Grab weights for reference path
+Ws = [getattr(glinear, f"weight{e}") for e in range(E)]  # each [N, K]
+W = torch.stack(Ws, dim=0)  # [E, N, K]
+assert W.shape == (E, N, K), f"Unexpected weight shape: {W.shape}"
+
+# Torch reference (group loop)
+offsets = []
+off = 0
+for m in m_splits:
+    offsets.append(off)
+    off += m
+
+y_ref_buf = torch.empty((M_total, N), device=device, dtype=dtype)
+
+def torch_run():
+    # Fill the preallocated buffer
+    for e, m in enumerate(m_splits):
+        if m == 0:
+            continue
+        o = offsets[e]
+        y_ref_buf[o:o+m].copy_(x[o:o+m] @ W[e].transpose(0, 1))
+    return y_ref_buf
+
+torch_ms = bench_cuda(torch_run, warmup=20, iters=100)
+
+# Compare outputs
+y_te = te_run()
+y_ref = torch_run().clone()
+
+diff = (y_te.float() - y_ref.float())
+max_abs = diff.abs().max().item()
+rel = (diff.abs() / (y_ref.float().abs() + 1e-6)).max().item()
+
+print(f"Errors:")
+print(f"  {y_te.shape=}, {y_ref.shape=}")
+print("  max_abs_err:", max_abs)
+print("  max_rel_err:", rel)
+
+torch.testing.assert_close(y_te.float(), y_ref.float(), rtol=3e-2, atol=3e-2)
+
+print(f"\nTiming:")
+print(f"  TE avg:    {te_ms:.3f} ms")
+print(f"  Torch avg: {torch_ms:.3f} ms")
+print(f"  Speedup:   {torch_ms/te_ms:.2f}x (Torch / TE)")
diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index a4dfd64ba..5f1489f88 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -1385,7 +1385,7 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_
 
     if IS_HIP_EXTENSION:
         if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias:
-            pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.")
+            pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.")
 
     te_linear_ref = Linear(
         config.hidden_size,
@@ -1677,7 +1677,7 @@ def test_layernorm_linear_accuracy_delay_wgrad_compute(
 ):
     if IS_HIP_EXTENSION:
         if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias:
-            pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.")
+            pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.")
 
     config = model_configs[model]
     ln_linear_ref = LayerNormLinear(
@@ -1891,7 +1891,7 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute(
 
     if IS_HIP_EXTENSION:
         if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias:
-            pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.")
+            pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.")
 
     ln_mlp = LayerNormMLP(
         hidden_size=config.hidden_size,
@@ -2036,7 +2036,7 @@ def test_grouped_linear_accuracy(
 
     if IS_HIP_EXTENSION:
         if dtype not in (torch.float32,) and fuse_wgrad_accumulation and not fp8:
-            pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.")
+            pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.")
 
     if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
@@ -2115,7 +2115,7 @@ def test_grouped_linear_accuracy(
 
 
 @pytest.mark.skipif(
-    torch.cuda.get_device_capability() != (9, 0),
+    torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION,
     reason="Only enable CUTLASS grouped gemm on Hopper",
 )
 @pytest.mark.parametrize("dtype", param_types, ids=str)
@@ -2133,6 +2133,9 @@ def test_grouped_linear_accuracy_cutlass(
     delay_wgrad_compute,
 ):
     os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
+    if IS_HIP_EXTENSION:
+        os.environ["NVTE_USE_CK_GROUPED_GEMM"] = "1"
+        os.environ["NVTE_CK_GROUPED_GEMM_WARN_FALLBACK"] = "1"
     test_grouped_linear_accuracy(
         dtype,
         num_gemms,
@@ -2147,6 +2150,9 @@ def test_grouped_linear_accuracy(
         use_cutlass=True,
     )
     os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
+    if IS_HIP_EXTENSION:
+        os.environ.pop("NVTE_USE_CK_GROUPED_GEMM", None)
+        os.environ.pop("NVTE_CK_GROUPED_GEMM_WARN_FALLBACK", None)
 
 
 @pytest.mark.parametrize("dtype", param_types, ids=str)
diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt
index cefec6d06..56207f16d 100644
--- a/transformer_engine/common/CMakeLists.txt
+++ b/transformer_engine/common/CMakeLists.txt
@@ -241,6 +241,14 @@ endif()
 
 target_include_directories(transformer_engine PUBLIC
                            "${CMAKE_CURRENT_SOURCE_DIR}/include")
+set(CK_ROOT ${CMAKE_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel)
+
+target_include_directories(transformer_engine
+    BEFORE PRIVATE
+    ${CK_ROOT}/include
+)
+
+
 if (USE_CUDA)
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
     set_source_files_properties(
diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cuh b/transformer_engine/common/gemm/ck_grouped_gemm.cuh
new file mode 100644
index 000000000..2ae402c47
--- /dev/null
+++ b/transformer_engine/common/gemm/ck_grouped_gemm.cuh
@@ -0,0 +1,387 @@
+/* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved.
+ */
+
+#include <array>
+#include <vector>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/epilogue.hpp"
+#include "ck_tile/ops/gemm.hpp"
+
+using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
+using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
+
+// Map TE dtypes to CK-Tile element types.
+template <transformer_engine::DType dtype>
+struct TeDTypeToCk;
+template <> struct TeDTypeToCk<transformer_engine::DType::kFloat16> { using type = ck_tile::half_t; };
+template <> struct TeDTypeToCk<transformer_engine::DType::kBFloat16> { using type = ck_tile::bfloat16_t; };
+
+static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) {
+  return t.data;  // rowwise data view
+}
+
+struct TileCfg_basic {
+  static constexpr ck_tile::index_t M_Tile = 256;
+  static constexpr ck_tile::index_t N_Tile = 128;
+  static constexpr ck_tile::index_t K_Tile = 64;
+
+  static constexpr ck_tile::index_t M_Warp = 2;
+  static constexpr ck_tile::index_t N_Warp = 2;
+  static constexpr ck_tile::index_t K_Warp = 1;
+
+  static constexpr ck_tile::index_t M_Warp_Tile = 32;
+  static constexpr ck_tile::index_t N_Warp_Tile = 32;
+  static constexpr ck_tile::index_t K_Warp_Tile = 16;
+
+  static constexpr bool kPadM = true;
+  static constexpr bool kPadN = true;
+  static constexpr bool kPadK = true;
+
+  static constexpr bool DoubleSmemBuffer = false;
+
+  static constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
+  static constexpr ck_tile::index_t TilePartitionerM01 = 1;
+};
+
+template <typename AType, typename BType, typename AccType, typename CType, typename ALayout,
+          typename BLayout, typename CLayout, typename TileCfg, ck_tile::memory_operation_enum MemOp>
+class Runner {
+ public:
+  using GemmShape = ck_tile::TileGemmShape<
+      ck_tile::sequence<TileCfg::M_Tile, TileCfg::N_Tile, TileCfg::K_Tile>,
+      ck_tile::sequence<TileCfg::M_Warp, TileCfg::N_Warp, TileCfg::K_Warp>,
+      ck_tile::sequence<TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile>>;
+
+  using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
+      GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>;
+
+  using UniversalTraits = ck_tile::PersistentTileGemmUniversalTraits<
+      TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK,
+      TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>;
+
+  static constexpr ck_tile::GemmPipelineScheduler Scheduler =
+      ck_tile::GemmPipelineScheduler::Intrawave;
+
+  using Problem = ck_tile::UniversalGemmPipelineProblem<
+      AType, BType, AccType, GemmShape, UniversalTraits, Scheduler>;
+
+  using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>;
+
+  using Epilogue = ck_tile::CShuffleEpilogue<
+      ck_tile::CShuffleEpilogueProblem<
+          AType, BType, ck_tile::tuple<>, AccType,
+          CType, ck_tile::tuple<>, CLayout,
+          ck_tile::element_wise::PassThrough,
+          Partitioner::MPerBlock, Partitioner::NPerBlock,
+          TileCfg::M_Warp, TileCfg::N_Warp,
+          TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile,
+          Problem::TransposeC, MemOp>>;
+
+  using Kernel = ck_tile::GroupedGemmKernel<Partitioner, Pipeline, Epilogue>;
+};
+
+template <typename Kernel>
+static inline void launch_tileloop_kernel(const ck_tile::stream_config& s,
+                                          dim3 grids,
+                                          ck_tile::index_t group_num,
+                                          void* kargs_dev)
+{
+  const dim3 blocks = Kernel::BlockSize();
+
+  ck_tile::launch_kernel(
+      s,
+      ck_tile::make_kernel<1>(
+          Kernel{}, grids, blocks, 0,
+          ck_tile::cast_pointer_to_constant_address_space(kargs_dev),
+          group_num));
+}
+
+template <typename T, typename ALayout, typename BLayout, ck_tile::memory_operation_enum MemOp>
+static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use,
+                             const transformer_engine::Tensor* const* B_use,
+                             transformer_engine::Tensor* const* D,
+                             int group_num,
+                             bool transA_use,
+                             bool transB_use,
+                             void* workspace,
+                             size_t workspace_bytes,
+                             hipStream_t stream)
+{
+  using R = Runner<T, T, float, T, ALayout, BLayout, RowMajor, TileCfg_basic, MemOp>;
+  using Kernel = typename R::Kernel;
+
+  const size_t needed = Kernel::GetWorkSpaceSize(group_num);
+  if (!workspace || workspace_bytes < needed) {
+    NVTE_ERROR("grouped_gemm_ck_tile: insufficient workspace. Needed bytes=", needed);
+    return false;
+  }
+
+  std::vector<ck_tile::GroupedGemmHostArgs<>> descs;
+  descs.reserve(group_num);
+
+  // One GEMM descriptor per group: D[i] = op(A[i]) * op(B[i]) with k_batch = 1.
+  for (int i = 0; i < group_num; ++i) {
+    const auto& a = data_view(*A_use[i]);
+    const auto& b = data_view(*B_use[i]);
+    const auto& d = data_view(*D[i]);
+
+    if (a.shape.size() != 2 || b.shape.size() != 2 || d.shape.size() != 2) {
+      NVTE_ERROR("grouped_gemm_ck_tile: expected all groups to be 2D.");
+      return false;
+    }
+
+    const int64_t Ad0 = a.shape[0];
+    const int64_t Ad1 = a.shape[1];
+    const int64_t Bd0 = b.shape[0];
+    const int64_t Bd1 = b.shape[1];
+
+    const int64_t M = transA_use ? Ad1 : Ad0;
+    const int64_t K = transA_use ? Ad0 : Ad1;
+    const int64_t N = transB_use ? Bd0 : Bd1;
+    const int64_t Kb = transB_use ? Bd1 : Bd0;
+
+    if (Kb != K) {
+      NVTE_ERROR("grouped_gemm_ck_tile: K mismatch between A and B in group ", i);
+      return false;
+    }
+
+    if (d.shape[0] != M || d.shape[1] != N) {
+      NVTE_ERROR("grouped_gemm_ck_tile: D shape mismatch in group ", i);
+      return false;
+    }
+
+    const ck_tile::index_t stride_A = a.shape[1];
+    const ck_tile::index_t stride_B = b.shape[1];
+    const ck_tile::index_t stride_E = d.shape[1];
+
+    descs.emplace_back(
+        a.dptr,
+        b.dptr,
+        std::array<const void*, 0>{},
+        d.dptr,
+        1,
+        M,
+        N,
+        K,
+        stride_A,
+        stride_B,
+        std::array<ck_tile::index_t, 0>{},
+        stride_E);
+  }
+
+  const dim3 grids = Kernel::GridSize(descs);
+  auto kargs = Kernel::MakeKargs(descs);
+  if (!Kernel::IsSupportedArgument(kargs)) {
+    NVTE_ERROR("grouped_gemm_ck_tile: CK-Tile kernel arguments not supported for this config.");
+    return false;
+  }
+
+  HIP_CHECK_ERROR(hipMemcpyAsync(workspace,
+                                 kargs.data(),
+                                 kargs.size() * sizeof(typename decltype(kargs)::value_type),
+                                 hipMemcpyHostToDevice,
+                                 stream));
+
+  const ck_tile::stream_config s{stream};
+  launch_tileloop_kernel<Kernel>(s, grids, group_num, workspace);
+  return true;
+}
+
+static inline bool infer_gemm_mode_group0(const transformer_engine::Tensor* const* A,
+                                          const transformer_engine::Tensor* const* B,
+                                          transformer_engine::Tensor* const* D,
+                                          int group_num,
+                                          const transformer_engine::Tensor* const*& A_use,
+                                          const transformer_engine::Tensor* const*& B_use,
+                                          bool& transA_use,
+                                          bool& transB_use)
+{
+  A_use = A;
+  B_use = B;
+  transA_use = false;
+  transB_use = false;
+
+  if (group_num <= 0)
+    return true;
+
+  const auto& a0 = data_view(*A[0]);
+  const auto& b0 = data_view(*B[0]);
+  const auto& d0 = data_view(*D[0]);
+
+  if (a0.shape.size() != 2 || b0.shape.size() != 2 || d0.shape.size() != 2) {
+    return false;
+  }
+
+  const int64_t Ad0 = a0.shape[0];
+  const int64_t Ad1 = a0.shape[1];
+  const int64_t Bd0 = b0.shape[0];
+  const int64_t Bd1 = b0.shape[1];
+  const int64_t Dm = d0.shape[0];
+  const int64_t Dn = d0.shape[1];
+
+  auto check = [&](bool do_swap, bool ta, bool tb) -> bool {
+    const int64_t A0d0 = do_swap ? Bd0 : Ad0;
+    const int64_t A0d1 = do_swap ? Bd1 : Ad1;
+    const int64_t B0d0 = do_swap ? Ad0 : Bd0;
+    const int64_t B0d1 = do_swap ? Ad1 : Bd1;
+
+    const int64_t M = ta ? A0d1 : A0d0;
+    const int64_t K = ta ? A0d0 : A0d1;
+    const int64_t N = tb ? B0d0 : B0d1;
+    const int64_t Kb = tb ? B0d1 : B0d0;
+
+    return (M == Dm) && (N == Dn) && (K == Kb);
+  };
+
+  // Try all candidates; prefer "no swap" first, then swap.
+  for (bool do_swap : {false, true}) {
+    for (bool ta : {false, true}) {
+      for (bool tb : {false, true}) {
+        if (check(do_swap, ta, tb)) {
+          A_use = do_swap ? B : A;
+          B_use = do_swap ? A : B;
+          transA_use = ta;
+          transB_use = tb;
+          return true;
+        }
+      }
+    }
+  }
+
+  // Nothing matched D = op(A) * op(B)
+  return false;
+}
+
+bool grouped_gemm_ck_tile(const transformer_engine::Tensor* const* A,
+                          const transformer_engine::Tensor* const* B,
+                          transformer_engine::Tensor* const* D,
+                          int group_num,
+                          bool transA,
+                          bool transB,
+                          void* workspace,
+                          size_t workspace_bytes,
+                          bool accumulate,
+                          hipStream_t stream)
+{
+  const transformer_engine::Tensor* const* A_use = A;
+  const transformer_engine::Tensor* const* B_use = B;
+  bool transA_use = transA;
+  bool transB_use = transB;
+
+  // If TE's flags disagree with storage, infer the correct mode from shapes.
+  if (!infer_gemm_mode_group0(A, B, D, group_num, A_use, B_use, transA_use, transB_use)) {
+    const auto& a0 = data_view(*A[0]);
+    const auto& b0 = data_view(*B[0]);
+    const auto& d0 = data_view(*D[0]);
+    NVTE_ERROR("grouped_gemm_ck_tile: could not infer a consistent GEMM mode from shapes. ",
+               "A0=[", a0.shape[0], ",", a0.shape[1], "] ",
+               "B0=[", b0.shape[0], ",", b0.shape[1], "] ",
+               "D0=[", d0.shape[0], ",", d0.shape[1], "] ",
+               "given flags transA=", transA, " transB=", transB);
+    return false;
+  }
+
+  const auto a_dtype = A_use[0]->dtype();
+
+  // accumulate=true is lowered to atomic_add stores in the epilogue (D += A*B); otherwise plain set.
+  const auto memop = accumulate ? ck_tile::memory_operation_enum::atomic_add
+                                : ck_tile::memory_operation_enum::set;
+
+  // Row-major storage: a transposed operand is addressed as its column-major view.
+  if (a_dtype == transformer_engine::DType::kFloat16) {
+    using T = TeDTypeToCk<transformer_engine::DType::kFloat16>::type;
+
+    if (!transA_use && !transB_use)
+      return (memop == ck_tile::memory_operation_enum::set)
+                 ? run_grouped_impl<T, RowMajor, RowMajor, ck_tile::memory_operation_enum::set>(
+                       A_use, B_use, D, group_num, false, false, workspace, workspace_bytes, stream)
+                 : run_grouped_impl<T, RowMajor, RowMajor, ck_tile::memory_operation_enum::atomic_add>(
+                       A_use, B_use, D, group_num, false, false, workspace, workspace_bytes, stream);
+
+    if (!transA_use && transB_use)
+      return (memop == ck_tile::memory_operation_enum::set)
+                 ? run_grouped_impl<T, RowMajor, ColMajor, ck_tile::memory_operation_enum::set>(
+                       A_use, B_use, D, group_num, false, true, workspace, workspace_bytes, stream)
+                 : run_grouped_impl<T, RowMajor, ColMajor, ck_tile::memory_operation_enum::atomic_add>(
+                       A_use, B_use, D, group_num, false, true, workspace, workspace_bytes, stream);
+
+    if (transA_use && !transB_use)
+      return (memop == ck_tile::memory_operation_enum::set)
+                 ? run_grouped_impl<T, ColMajor, RowMajor, ck_tile::memory_operation_enum::set>(
+                       A_use, B_use, D, group_num, true, false, workspace, workspace_bytes, stream)
+                 : run_grouped_impl<T, ColMajor, RowMajor, ck_tile::memory_operation_enum::atomic_add>(
+                       A_use, B_use, D, group_num, true, false, workspace, workspace_bytes, stream);
+
+    return (memop == ck_tile::memory_operation_enum::set)
+               ? run_grouped_impl<T, ColMajor, ColMajor, ck_tile::memory_operation_enum::set>(
+                     A_use, B_use, D, group_num, true, true, workspace, workspace_bytes, stream)
+               : run_grouped_impl<T, ColMajor, ColMajor, ck_tile::memory_operation_enum::atomic_add>(
+                     A_use, B_use, D, group_num, true, true, workspace, workspace_bytes, stream);
+  } else {
+    using T = TeDTypeToCk<transformer_engine::DType::kBFloat16>::type;
+
+    if (!transA_use && !transB_use)
+      return (memop == ck_tile::memory_operation_enum::set)
+                 ? run_grouped_impl<T, RowMajor, RowMajor, ck_tile::memory_operation_enum::set>(
+                       A_use, B_use, D, group_num, false, false, workspace, workspace_bytes, stream)
+                 : run_grouped_impl<T, RowMajor, RowMajor, ck_tile::memory_operation_enum::atomic_add>(
+                       A_use, B_use, D, group_num, false, false, workspace, workspace_bytes, stream);
+
+    if (!transA_use && transB_use)
+      return (memop == ck_tile::memory_operation_enum::set)
+                 ? run_grouped_impl<T, RowMajor, ColMajor, ck_tile::memory_operation_enum::set>(
+                       A_use, B_use, D, group_num, false, true, workspace, workspace_bytes, stream)
+                 : run_grouped_impl<T, RowMajor, ColMajor, ck_tile::memory_operation_enum::atomic_add>(
+                       A_use, B_use, D, group_num, false, true, workspace, workspace_bytes, stream);
+
+    if (transA_use && !transB_use)
+      return (memop == ck_tile::memory_operation_enum::set)
+                 ? run_grouped_impl<T, ColMajor, RowMajor, ck_tile::memory_operation_enum::set>(
+                       A_use, B_use, D, group_num, true, false, workspace, workspace_bytes, stream)
+                 : run_grouped_impl<T, ColMajor, RowMajor, ck_tile::memory_operation_enum::atomic_add>(
+                       A_use, B_use, D, group_num, true, false, workspace, workspace_bytes, stream);
+
+    return (memop == ck_tile::memory_operation_enum::set)
+               ? run_grouped_impl<T, ColMajor, ColMajor, ck_tile::memory_operation_enum::set>(
+                     A_use, B_use, D, group_num, true, true, workspace, workspace_bytes, stream)
+               : run_grouped_impl<T, ColMajor, ColMajor, ck_tile::memory_operation_enum::atomic_add>(
+                     A_use, B_use, D, group_num, true, true, workspace, workspace_bytes, stream);
+  }
+}
+
+bool grouped_gemm_ck_tile(const NVTETensor* A,
+                          const NVTETensor* B,
+                          NVTETensor* D,
+                          int group_num,
+                          bool transA,
+                          bool transB,
+                          NVTETensor* workspace,
+                          bool accumulate,
+                          hipStream_t stream)
+{
+  if (group_num <= 0)
+    return true;
+
+  // Convert A/B/D arrays into TE Tensor* arrays
+  std::vector<transformer_engine::Tensor*> A_te(group_num);
+  std::vector<transformer_engine::Tensor*> B_te(group_num);
+  std::vector<transformer_engine::Tensor*> D_te(group_num);
+
+  for (int i = 0; i < group_num; ++i) {
+    A_te[i] = transformer_engine::convertNVTETensorCheck(A[i]);
+    B_te[i] = transformer_engine::convertNVTETensorCheck(B[i]);
+    D_te[i] = transformer_engine::convertNVTETensorCheck(D[i]);
+  }
+
+  // Workspace pointer + bytes
+  void* ws_ptr = nullptr;
+  size_t ws_bytes = 0;
+  if (workspace) {
+    auto* ws_te = transformer_engine::convertNVTETensorCheck(*workspace);
+    ws_ptr = ws_te->data.dptr;
+    ws_bytes = ws_te->data.numel() * transformer_engine::typeToSize(ws_te->data.dtype);
+  }
+
+  return grouped_gemm_ck_tile(A_te.data(), B_te.data(), D_te.data(),
+                              group_num, transA, transB,
+                              ws_ptr, ws_bytes, accumulate,
+                              stream);
+}
diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu
index 9c2ca9b4c..fcbdac91c 100644
--- a/transformer_engine/common/gemm/cublaslt_gemm.cu
+++ b/transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -1,6 +1,6 @@
 /*************************************************************************
  * This file was modified for portability to AMDGPU
- * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
@@ -24,8 +24,11 @@
 #include "../util/logging.h"
 #include "../util/multi_stream.h"
 #include "common/util/cuda_runtime.h"
+#include "common/util/system.h"
 #ifndef __HIP_PLATFORM_AMD__
 #include "cutlass_grouped_gemm.cuh"
+#else
+#include "ck_grouped_gemm.cuh"
 #endif
 
 #ifndef __HIP_PLATFORM_AMD__
@@ -788,7 +791,35 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
   NVTE_API_CALL(nvte_multi_tensor_gemm);
 
 #ifdef __HIP_PLATFORM_AMD__
-  multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad,
+  const bool use_ck = transformer_engine::getenv("NVTE_USE_CK_GROUPED_GEMM", false);
+  const bool warn_fallback =
+      transformer_engine::getenv("NVTE_CK_GROUPED_GEMM_WARN_FALLBACK", false);
+
+  auto is_supported_dtype = [&]() -> bool {
+    auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]);
+    auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]);
+    auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]);
+    auto A_dt = inputA->data.dtype;
+    auto B_dt = inputB->data.dtype;
+    auto D_dt = OutputD->data.dtype;
+
+    return (A_dt == B_dt) && (A_dt == D_dt) &&
+           (A_dt == transformer_engine::DType::kFloat16 ||
+            A_dt == transformer_engine::DType::kBFloat16);
+  };
+
+  if (use_ck && is_supported_dtype()) {
+    if (grouped_gemm_ck_tile(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream)) {
+      // NVTE_WARN("grouped_gemm_ck_tile done.\n");
+      return;
+    } else if (warn_fallback) {
+      NVTE_WARN("Fallback to hipBLASLt grouped GEMM (grouped_gemm_ck_tile returned false).");
+    }
+  } else if (warn_fallback) {
+    NVTE_WARN("Fallback to hipBLASLt grouped GEMM (CK config unsupported or CK disabled). use_ck=", use_ck, " is_supported_dtype=", is_supported_dtype());
+  }
+
+  multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad,
                            workspace, accumulate, use_split_accumulator, math_sm_count, stream);
 #else
   const int current_device = transformer_engine::cuda::current_device();
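
Usage sketch (illustrative, not part of the patch): the CK-Tile grouped GEMM path is opt-in through the two environment variables above, so the smallest way to exercise it end to end on a ROCm build is a single te.GroupedLinear forward. The variable names and te.GroupedLinear come from the patch; the shapes below are just the gmm2.py benchmark defaults.

    import os

    # Opt in to the CK-Tile grouped GEMM and warn when it falls back to hipBLASLt.
    os.environ["NVTE_USE_CK_GROUPED_GEMM"] = "1"
    os.environ["NVTE_CK_GROUPED_GEMM_WARN_FALLBACK"] = "1"

    import torch
    import transformer_engine.pytorch as te

    layer = te.GroupedLinear(4, 1024, 2048, bias=False).to(device="cuda", dtype=torch.bfloat16)
    x = torch.randn(448, 1024, device="cuda", dtype=torch.bfloat16)
    y = layer(x, m_splits=[128, 64, 0, 256])  # grouped GEMM dispatched via nvte_multi_tensor_gemm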