From 0e35b90e475be4767737009a5443cdb16e489962 Mon Sep 17 00:00:00 2001
From: ivandobskygithub <84400859+ivandobskygithub@users.noreply.github.com>
Date: Tue, 25 Nov 2025 16:51:45 +0000
Subject: [PATCH] Fix load_q_impl tensor templates

---
 hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
index 7f52bbea904..91eed010dcc 100644
--- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
+++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
@@ -119,6 +119,7 @@ struct CollectiveMainloopFwdSm90 {
     static constexpr int NumMmaThreadsQK = size(TiledMmaQK{});
     static constexpr int NumMmaThreads = size(TiledMmaPV{});
     static constexpr int NumProducerThreads = !Transpose_V && Use_TMA_KV && Use_TMA_Q ? cutlass::NumThreadsPerWarp : cutlass::NumThreadsPerWarpGroup;
+    static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;
     static_assert(NumMmaThreadsQK % cutlass::NumThreadsPerWarpGroup == 0);
     static_assert(NumMmaThreads % cutlass::NumThreadsPerWarpGroup == 0);
     static constexpr int NumMmaWarpGroups = NumMmaThreads / cutlass::NumThreadsPerWarpGroup;
@@ -646,11 +647,11 @@ struct CollectiveMainloopFwdSm90 {
         }
     }

-    template <typename Tensor>
+    template <typename TensorQ, typename TensorQv>
     CUTLASS_DEVICE void load_q_impl(
         std::true_type /*UseTmaQ*/, Params const& params, SharedStorage &shared_storage,
         SeqlenInfo_t const& seqlen_info, cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,
-        int thread_idx, bool is_varlen_q, int warp_idx_in_warpgroup, Tensor sQ, Tensor sQv) {
+        int thread_idx, bool is_varlen_q, int warp_idx_in_warpgroup, TensorQ sQ, TensorQv sQv) {

         int const m_block = get<0>(block_coord);
         int const bidh = get<1>(block_coord);
@@ -694,11 +695,11 @@ struct CollectiveMainloopFwdSm90 {
         }
     }

-    template <typename Tensor>
+    template <typename TensorQ, typename TensorQv>
     CUTLASS_DEVICE void load_q_impl(
         std::false_type /*UseTmaQ*/, Params const& params, SharedStorage &shared_storage,
         SeqlenInfo_t const& seqlen_info, cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,
-        int thread_idx, bool is_varlen_q, int /*warp_idx_in_warpgroup*/, Tensor sQ, Tensor sQv) {
+        int thread_idx, bool is_varlen_q, int /*warp_idx_in_warpgroup*/, TensorQ sQ, TensorQv sQv) {

         int const m_block = get<0>(block_coord);
         int const bidh = get<1>(block_coord);