diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 68a05c585d6..7f52bbea904 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -646,6 +646,83 @@ struct CollectiveMainloopFwdSm90 { } } + template + CUTLASS_DEVICE void load_q_impl( + std::true_type /*UseTmaQ*/, Params const& params, SharedStorage &shared_storage, + SeqlenInfo_t const& seqlen_info, cute::tuple block_coord, + int thread_idx, bool is_varlen_q, int warp_idx_in_warpgroup, Tensor sQ, Tensor sQv) { + + int const m_block = get<0>(block_coord); + int const bidh = get<1>(block_coord); + int const bidb = get<2>(block_coord); + + Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + auto block_tma_Q = params.tma_load_Q.get_slice(_0{}); + Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) + Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); // (TMA) + if (thread_idx == 0) { prefetch(block_tma_Q, tQgQ); } + + auto [tQvgQv, tQvsQv] = [&] { + if constexpr (HasQv) { + auto shape_Qv = make_shape(get<0>(params.shape_Q), params.headdim_v, get<2>(params.shape_Q), get<3>(params.shape_Q)); + Tensor mQv = params.tma_load_Qv.get_tma_tensor(shape_Qv)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor gQv = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQv), select<0, 2>(TileShape_MNK_QV{}), make_coord(m_block, _0{})); // (M, Kv) + auto block_tma_Qv = params.tma_load_Qv.get_slice(_0{}); + Tensor tQvgQv = group_modes<0, 3>(block_tma_Qv.partition_S(gQv)); // (TMA) + Tensor tQvsQv = group_modes<0, 3>(block_tma_Qv.partition_D(sQv)); // (TMA) + return cute::make_tuple(tQvgQv, tQvsQv); + } else { + return cute::make_tuple(nullptr, nullptr); + } + }(); + + // Wait for the MMA warpgroups to signal that smem_q is ready + if (SingleProducerWarp || warp_idx_in_warpgroup == 0) { + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + } + + if ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()) { + shared_storage.pipelines.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); + copy(params.tma_load_Q.with(reinterpret_cast(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), + tQgQ, tQsQ); + if constexpr (HasQv) { + shared_storage.pipelines.barrier_Qv.arrive_and_expect_tx(TmaTransactionBytesQv); + copy(params.tma_load_Qv.with(reinterpret_cast(shared_storage.pipelines.barrier_Qv), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), + tQvgQv, tQvsQv); + } + } + } + + template + CUTLASS_DEVICE void load_q_impl( + std::false_type /*UseTmaQ*/, Params const& params, SharedStorage &shared_storage, + SeqlenInfo_t const& seqlen_info, cute::tuple block_coord, + int thread_idx, bool is_varlen_q, int /*warp_idx_in_warpgroup*/, Tensor sQ, Tensor sQv) { + + int const m_block = get<0>(block_coord); + int const bidh = get<1>(block_coord); + int const bidb = get<2>(block_coord); + + cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); + using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element>; + PackGQAt::load_Q(mQ, sQ_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block); + auto &barrier_Q = shared_storage.pipelines.barrier_Q; + cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Q)); + barrier_Q.arrive(); + if constexpr (HasQv) { + Tensor mQv = make_tensor(make_gmem_ptr(params.ptr_Qv + seqlen_info.offset_q * get<0>(params.stride_Qv)), params.shape_Qv_packed, params.stride_Qv_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor sQv_pi = cute::as_position_independent_swizzle_tensor(sQv); + using PackGQAt = flash::PackGQAManager(TileShape_MNK_QV{}), get<2>(TileShape_MNK_QV{}), NumProducerThreads, Element>; + PackGQAt::load_Q(mQv, sQv_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block); + auto &barrier_Qv = shared_storage.pipelines.barrier_Qv; + cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Qv)); + barrier_Qv.arrive(); + } + } + template CUTLASS_DEVICE void load(Params const& params, @@ -850,61 +927,7 @@ struct CollectiveMainloopFwdSm90 { // if (thread_idx == 0) { printf("Producer: main load, after load K, index = %d\n", smem_pipe_write.index());} } - if constexpr (Use_TMA_Q) { - Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) - auto block_tma_Q = params.tma_load_Q.get_slice(_0{}); - Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) - Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); // (TMA) - if (thread_idx == 0) { prefetch(block_tma_Q, tQgQ); } - auto [tQvgQv, tQvsQv] = [&] { - if constexpr (HasQv) { - auto shape_Qv = make_shape(get<0>(params.shape_Q), params.headdim_v, get<2>(params.shape_Q), get<3>(params.shape_Q)); - Tensor mQv = params.tma_load_Qv.get_tma_tensor(shape_Qv)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor gQv = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQv), select<0, 2>(TileShape_MNK_QV{}), make_coord(m_block, _0{})); // (M, Kv) - auto block_tma_Qv = params.tma_load_Qv.get_slice(_0{}); - Tensor tQvgQv = group_modes<0, 3>(block_tma_Qv.partition_S(gQv)); // (TMA) - Tensor tQvsQv = group_modes<0, 3>(block_tma_Qv.partition_D(sQv)); // (TMA) - return cute::make_tuple(tQvgQv, tQvsQv); - } else { - return cute::make_tuple(nullptr, nullptr); - } - }(); - - // Wait for the MMA warpgroups to signal that smem_q is ready - if (SingleProducerWarp || warp_idx_in_warpgroup == 0) { - cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); - } - - if ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()) { - shared_storage.pipelines.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); - copy(params.tma_load_Q.with(reinterpret_cast(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), - tQgQ, tQsQ); - if constexpr (HasQv) { - shared_storage.pipelines.barrier_Qv.arrive_and_expect_tx(TmaTransactionBytesQv); - copy(params.tma_load_Qv.with(reinterpret_cast(shared_storage.pipelines.barrier_Qv), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST), - tQvgQv, tQvsQv); - } - } - } else { // Load Q with cp.async - cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + NumProducerThreads, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); - Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ); - using PackGQAt = flash::PackGQAManager(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element>; - PackGQAt::load_Q(mQ, sQ_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block); - auto &barrier_Q = shared_storage.pipelines.barrier_Q; - cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Q)); - barrier_Q.arrive(); - if constexpr (HasQv) { - Tensor mQv = make_tensor(make_gmem_ptr(params.ptr_Qv + seqlen_info.offset_q * get<0>(params.stride_Qv)), params.shape_Qv_packed, params.stride_Qv_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor sQv_pi = cute::as_position_independent_swizzle_tensor(sQv); - using PackGQAt = flash::PackGQAManager(TileShape_MNK_QV{}), get<2>(TileShape_MNK_QV{}), NumProducerThreads, Element>; - PackGQAt::load_Q(mQv, sQv_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block); - auto &barrier_Qv = shared_storage.pipelines.barrier_Qv; - cutlass::arch::cpasync_barrier_arrive(reinterpret_cast(&barrier_Qv)); - barrier_Qv.arrive(); - } - } + load_q_impl(std::integral_constant{}, params, shared_storage, seqlen_info, block_coord, thread_idx, is_varlen_q, warp_idx_in_warpgroup, sQ, sQv); // Wait for the MMA WGs to signal that smem_v are ready and V can be copied from gmem // Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finishing the