From b740212d668f64d718ca1d17155d4687be7df7bd Mon Sep 17 00:00:00 2001 From: ivandobskygithub <84400859+ivandobskygithub@users.noreply.github.com> Date: Tue, 25 Nov 2025 13:51:40 +0000 Subject: [PATCH] Guard Q TMA setup behind Use_TMA_Q --- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 393 ++++++++++++++--------- 1 file changed, 233 insertions(+), 160 deletions(-) diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 536ff855fd4..68a05c585d6 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -245,34 +245,64 @@ struct CollectiveMainloopFwdSm90 { using StrideRotary = cute::Stride; using StrideDescale = cute::Stride; - using TMA_Q = decltype(make_tma_copy_A_sm90( - GmemTiledCopyQ{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), - SmemLayoutQ{}, - TileShape_MNK{}, - ClusterShape{})); - - using TMA_K = decltype(make_tma_copy_B_sm90( - GmemTiledCopyKV{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), - take<0, 2>(SmemLayoutK{}), - TileShape_MNK{}, - ClusterShape{})); // mcast along M mode for this N load, if any - - using TMA_V = decltype(make_tma_copy( - GmemTiledCopyKV{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, select<1, 0, 2, 3>(StrideV{})), - take<0, 2>(SmemLayoutVt{}), - select<1, 2>(TileShape_MNK_PV{}), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any - - using TMA_Qv_ = decltype(make_tma_copy_A_sm90( - GmemTiledCopyQ{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), - SmemLayoutQv{}, - TileShape_MNK_QV{}, - ClusterShape{})); - using TMA_Qv = std::conditional_t; + struct DummyTmaDescriptor {}; + + static constexpr auto make_tma_q_type() { + if constexpr (Use_TMA_Q) { + return make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + SmemLayoutQ{}, + TileShape_MNK{}, + ClusterShape{}); + } else { + return DummyTmaDescriptor{}; + } + } + + static constexpr auto make_tma_k_type() { + if constexpr (Use_TMA_KV) { + return make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + take<0, 2>(SmemLayoutK{}), + TileShape_MNK{}, + ClusterShape{}); + } else { + return DummyTmaDescriptor{}; + } + } + + static constexpr auto make_tma_v_type() { + if constexpr (Use_TMA_KV) { + return make_tma_copy( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, select<1, 0, 2, 3>(StrideV{})), + take<0, 2>(SmemLayoutVt{}), + select<1, 2>(TileShape_MNK_PV{}), + size<0>(ClusterShape{})); + } else { + return DummyTmaDescriptor{}; + } + } + + static constexpr auto make_tma_qv_type() { + if constexpr (HasQv) { + return make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + SmemLayoutQv{}, + TileShape_MNK_QV{}, + ClusterShape{}); + } else { + return DummyTmaDescriptor{}; + } + } + + using TMA_Q = decltype(make_tma_q_type()); + using TMA_K = decltype(make_tma_k_type()); + using TMA_V = decltype(make_tma_v_type()); + using TMA_Qv = decltype(make_tma_qv_type()); // Set the bytes transferred in this TMA transaction (may involve multiple issues) static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); @@ -458,44 +488,74 @@ struct CollectiveMainloopFwdSm90 { static Params to_underlying_arguments(Arguments const& args) { Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q); - TMA_Q tma_load_Q = make_tma_copy_A_sm90( - GmemTiledCopyQ{}, - mQ, - SmemLayoutQ{}, - TileShape_MNK{}, - ClusterShape{}); // no mcast for Q + TMA_Q tma_load_Q = [&] { + if constexpr (Use_TMA_Q) { + return make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + mQ, + SmemLayoutQ{}, + TileShape_MNK{}, + ClusterShape{}); // no mcast for Q + } else { + return TMA_Q{}; + } + }(); Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K); - TMA_K tma_load_K = make_tma_copy_B_sm90( - GmemTiledCopyKV{}, - mK, - take<0, 2>(SmemLayoutK{}), - TileShape_MNK{}, - ClusterShape{}); // mcast along M mode for this N load, if any + TMA_K tma_load_K = [&] { + if constexpr (Use_TMA_KV) { + return make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + mK, + take<0, 2>(SmemLayoutK{}), + TileShape_MNK{}, + ClusterShape{}); // mcast along M mode for this N load, if any + } else { + return TMA_K{}; + } + }(); Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), make_shape(args.headdim_v, get<0>(args.shape_K), get<2>(args.shape_K), get<3>(args.shape_K)), select<1, 0, 2, 3>(args.stride_V)); - TMA_V tma_load_V = make_tma_copy( - GmemTiledCopyKV{}, - mV, - take<0, 2>(SmemLayoutVt{}), - select<1, 2>(TileShape_MNK_PV{}), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + TMA_V tma_load_V = [&] { + if constexpr (Use_TMA_KV) { + return make_tma_copy( + GmemTiledCopyKV{}, + mV, + take<0, 2>(SmemLayoutVt{}), + select<1, 2>(TileShape_MNK_PV{}), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + } else { + return TMA_V{}; + } + }(); Tensor mKnew = make_tensor(make_gmem_ptr(args.ptr_K_new), args.shape_K_new, args.stride_K_new); - TMA_K tma_load_K_new = make_tma_copy_B_sm90( - GmemTiledCopyKV{}, - cute::conditional_return(mKnew, mK), - take<0, 2>(SmemLayoutK{}), - TileShape_MNK{}, - ClusterShape{}); // mcast along M mode for this N load, if any + TMA_K tma_load_K_new = [&] { + if constexpr (Use_TMA_KV) { + return make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + cute::conditional_return(mKnew, mK), + take<0, 2>(SmemLayoutK{}), + TileShape_MNK{}, + ClusterShape{}); // mcast along M mode for this N load, if any + } else { + return TMA_K{}; + } + }(); Tensor mVnew = make_tensor(make_gmem_ptr(args.ptr_V_new), make_shape(args.headdim_v, get<0>(args.shape_K_new), get<2>(args.shape_K_new), get<3>(args.shape_K_new)), select<1, 0, 2, 3>(args.stride_V_new)); - TMA_V tma_load_V_new = make_tma_copy( - GmemTiledCopyKV{}, - cute::conditional_return(mVnew, mV), - take<0, 2>(SmemLayoutVt{}), - select<1, 2>(TileShape_MNK_PV{}), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + TMA_V tma_load_V_new = [&] { + if constexpr (Use_TMA_KV) { + return make_tma_copy( + GmemTiledCopyKV{}, + cute::conditional_return(mVnew, mV), + take<0, 2>(SmemLayoutVt{}), + select<1, 2>(TileShape_MNK_PV{}), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + } else { + return TMA_V{}; + } + }(); auto shape_Qv = make_shape(get<0>(args.shape_Q), args.headdim_v, get<2>(args.shape_Q), get<3>(args.shape_Q)); Tensor mQv = make_tensor(make_gmem_ptr(args.ptr_Qv), shape_Qv, args.stride_Qv); TMA_Qv tma_load_Qv = [&] { @@ -507,7 +567,7 @@ struct CollectiveMainloopFwdSm90 { TileShape_MNK_QV{}, ClusterShape{}); // no mcast for Qv } else { - return nullptr; + return TMA_Qv{}; } }(); // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) @@ -580,7 +640,7 @@ struct CollectiveMainloopFwdSm90 { cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor()); cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor()); } - if constexpr (AppendKV) { + if constexpr (AppendKV && Use_TMA_KV) { cute::prefetch_tma_descriptor(params.tma_load_K_new.get_tma_descriptor()); cute::prefetch_tma_descriptor(params.tma_load_V_new.get_tma_descriptor()); } @@ -652,20 +712,14 @@ struct CollectiveMainloopFwdSm90 { bool const is_varlen_q = Varlen && params.cu_seqlens_q; bool const is_varlen_k = Varlen && params.cu_seqlens_k; - Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, _); auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(shape_V)(_, _, bidh_kv, _); - Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) // if (cute::thread0()) { printf("Varlen = %d, params.leftpad_k = %p, leftpad_k = %d\n", Varlen, params.leftpad_k, leftpad_k); } Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}, _)); // (N, K, _, _) Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k, _0{}), mVt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _, _)); // (K, N, _, _) - auto block_tma_Q = params.tma_load_Q.get_slice(_0{}); - Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) - Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); // (TMA) - if (Use_TMA_Q && thread_idx == 0) { prefetch(params.tma_load_Q, tQgQ); } // tma_partition doesn't handle position_independent_swizzle_tensor correctly, so we need to do it manually auto block_tma_K = params.tma_load_K.get_slice(cluster_local_block_id.x); Tensor tKgK_TMA = group_modes<0, 3>(block_tma_K.partition_S(gK_TMA)); // (TMA, k, batch) @@ -673,19 +727,6 @@ struct CollectiveMainloopFwdSm90 { auto block_tma_V = params.tma_load_V.get_slice(cluster_local_block_id.x); Tensor tVgVt_TMA = group_modes<0, 3>(block_tma_V.partition_S(gVt_TMA)); // (TMA, k, batch) Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V.partition_D(sVt)); // (TMA, PIPE) - auto [tQvgQv, tQvsQv] = [&] { - if constexpr (HasQv) { - auto shape_Qv = make_shape(get<0>(params.shape_Q), params.headdim_v, get<2>(params.shape_Q), get<3>(params.shape_Q)); - Tensor mQv = params.tma_load_Qv.get_tma_tensor(shape_Qv)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor gQv = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQv), select<0, 2>(TileShape_MNK_QV{}), make_coord(m_block, _0{})); // (M, Kv) - auto block_tma_Qv = params.tma_load_Qv.get_slice(_0{}); - Tensor tQvgQv = group_modes<0, 3>(block_tma_Qv.partition_S(gQv)); // (TMA) - Tensor tQvsQv = group_modes<0, 3>(block_tma_Qv.partition_D(sQv)); // (TMA) - return cute::make_tuple(tQvgQv, tQvsQv); - } else { - return cute::make_tuple(nullptr, nullptr); - } - }(); // This is used to index into the batch dimension of mK and mV int const bidb_kv_idx = !is_varlen_k && !params.ptr_pagetable ? bidb_kv : 0; @@ -810,6 +851,26 @@ struct CollectiveMainloopFwdSm90 { } if constexpr (Use_TMA_Q) { + Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + auto block_tma_Q = params.tma_load_Q.get_slice(_0{}); + Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) + Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); // (TMA) + if (thread_idx == 0) { prefetch(block_tma_Q, tQgQ); } + auto [tQvgQv, tQvsQv] = [&] { + if constexpr (HasQv) { + auto shape_Qv = make_shape(get<0>(params.shape_Q), params.headdim_v, get<2>(params.shape_Q), get<3>(params.shape_Q)); + Tensor mQv = params.tma_load_Qv.get_tma_tensor(shape_Qv)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor gQv = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQv), select<0, 2>(TileShape_MNK_QV{}), make_coord(m_block, _0{})); // (M, Kv) + auto block_tma_Qv = params.tma_load_Qv.get_slice(_0{}); + Tensor tQvgQv = group_modes<0, 3>(block_tma_Qv.partition_S(gQv)); // (TMA) + Tensor tQvsQv = group_modes<0, 3>(block_tma_Qv.partition_D(sQv)); // (TMA) + return cute::make_tuple(tQvgQv, tQvsQv); + } else { + return cute::make_tuple(nullptr, nullptr); + } + }(); + // Wait for the MMA warpgroups to signal that smem_q is ready if (SingleProducerWarp || warp_idx_in_warpgroup == 0) { cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); @@ -1453,97 +1514,109 @@ struct CollectiveMainloopFwdSm90 { cute::tuple block_coord, int const work_idx ) { + if constexpr (Use_TMA_KV) { - auto [m_block, bidh, bidb, split_idx] = block_coord; - auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.attention_chunk_divmod, - params.qhead_per_khead_divmod); + auto [m_block, bidh, bidb, split_idx] = block_coord; + auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); - if (n_block_new_max <= n_block_new_min) { return false; } + if (n_block_new_max <= n_block_new_min) { return false; } - Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); - Tensor sVt = [&] { - if constexpr (!Transpose_V) { - return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{}); - } else { - return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVt{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sVt = [&] { + if constexpr (!Transpose_V) { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{}); + } else { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVt{}); + } + }(); + + // int const thread_idx = threadIdx.x % NumProducerThreads; + int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; + + // Prepare the TMA loads + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new; + Tensor mKnew_TMA = params.tma_load_K_new.get_tma_tensor(params.shape_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); + auto shape_Vnew = make_shape(params.headdim_v, get<0>(params.shape_K_new), get<2>(params.shape_K_new), get<3>(params.shape_K_new)); + Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(shape_Vnew)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); + + Tensor gKnew_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) + + auto block_tma_K_new = params.tma_load_K_new.get_slice(cluster_local_block_id.x); + Tensor tKgKnew_TMA = group_modes<0, 3>(block_tma_K_new.partition_S(gKnew_TMA)); // (TMA, k) + Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K_new.partition_D(sK)); // (TMA, PIPE) + auto block_tma_V_new = params.tma_load_V_new.get_slice(cluster_local_block_id.x); + Tensor tVgVnewt_TMA = group_modes<0, 3>(block_tma_V_new.partition_S(gVnewt_TMA)); // (TMA, k) + Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V_new.partition_D(sVt)); // (TMA, PIPE) + + uint16_t mcast_mask_kv = 0; + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{})); + } } - }(); - // int const thread_idx = threadIdx.x % NumProducerThreads; - int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; - - // Prepare the TMA loads - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - - bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new; - Tensor mKnew_TMA = params.tma_load_K_new.get_tma_tensor(params.shape_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); - auto shape_Vnew = make_shape(params.headdim_v, get<0>(params.shape_K_new), get<2>(params.shape_K_new), get<3>(params.shape_K_new)); - Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(shape_Vnew)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); - - Tensor gKnew_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) + auto load_K_new = [&] (int const n_block, auto const& smem_pipe_write) { + pipeline_k_new.producer_acquire(smem_pipe_write); + copy(params.tma_load_K_new.with(*pipeline_k_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST), + tKgKnew_TMA(_, n_block), tKsK_TMA(_, smem_pipe_write.index())); + }; - auto block_tma_K_new = params.tma_load_K_new.get_slice(cluster_local_block_id.x); - Tensor tKgKnew_TMA = group_modes<0, 3>(block_tma_K_new.partition_S(gKnew_TMA)); // (TMA, k) - Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K_new.partition_D(sK)); // (TMA, PIPE) - auto block_tma_V_new = params.tma_load_V_new.get_slice(cluster_local_block_id.x); - Tensor tVgVnewt_TMA = group_modes<0, 3>(block_tma_V_new.partition_S(gVnewt_TMA)); // (TMA, k) - Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V_new.partition_D(sVt)); // (TMA, PIPE) + auto load_V_new = [&] (int const n_block, auto const& smem_pipe_write) { + pipeline_v_new.producer_acquire(smem_pipe_write); + copy(params.tma_load_V_new.with(*pipeline_v_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST), + tVgVnewt_TMA(_, n_block), tVsVt_TMA(_, smem_pipe_write.index())); + }; - uint16_t mcast_mask_kv = 0; - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int m = 0; m < size<0>(block_layout); ++m) { - mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{})); + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + // If this is true, we're guaranteed that only the first warp will execute this function + static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; + bool should_load_KV = (SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync(); + + int n_block = n_block_new_max - 1; + // Need to wait for barrier_O even before load_K_new since the pipelines for AppendKV + // and the main attention are not the same. We want to make sure the consumers + // have finished reading all smem_k and smem_v for the previous iteration. + shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2); + if (should_load_KV) { load_K_new(n_block, smem_pipe_write); } + // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + if (should_load_KV) { load_V_new(n_block, smem_pipe_write); } + // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + ++smem_pipe_write; + --n_block; + // if (thread_idx == 0) { printf("Producer: before for loop\n"); } + #pragma unroll 1 + for (; n_block >= n_block_new_min; --n_block) { + if (should_load_KV) { + load_K_new(n_block, smem_pipe_write); + // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + load_V_new(n_block, smem_pipe_write); + // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + } + ++smem_pipe_write; } + // if (thread_idx == 0) { printf("Producer: after for loop\n"); } + // At the end, all threads have the correct smem_pipe_write. + return true; } - auto load_K_new = [&] (int const n_block, auto const& smem_pipe_write) { - pipeline_k_new.producer_acquire(smem_pipe_write); - copy(params.tma_load_K_new.with(*pipeline_k_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST), - tKgKnew_TMA(_, n_block), tKsK_TMA(_, smem_pipe_write.index())); - }; - - auto load_V_new = [&] (int const n_block, auto const& smem_pipe_write) { - pipeline_v_new.producer_acquire(smem_pipe_write); - copy(params.tma_load_V_new.with(*pipeline_v_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST), - tVgVnewt_TMA(_, n_block), tVsVt_TMA(_, smem_pipe_write.index())); - }; - - int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); - // If this is true, we're guaranteed that only the first warp will execute this function - static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; - bool should_load_KV = (SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync(); - - int n_block = n_block_new_max - 1; - // Need to wait for barrier_O even before load_K_new since the pipelines for AppendKV - // and the main attention are not the same. We want to make sure the consumers - // have finished reading all smem_k and smem_v for the previous iteration. - shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2); - if (should_load_KV) { load_K_new(n_block, smem_pipe_write); } - // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } - if (should_load_KV) { load_V_new(n_block, smem_pipe_write); } - // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } - ++smem_pipe_write; - --n_block; - // if (thread_idx == 0) { printf("Producer: before for loop\n"); } - #pragma unroll 1 - for (; n_block >= n_block_new_min; --n_block) { - if (should_load_KV) { - load_K_new(n_block, smem_pipe_write); - // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } - load_V_new(n_block, smem_pipe_write); - // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } - } - ++smem_pipe_write; - } - // if (thread_idx == 0) { printf("Producer: after for loop\n"); } - // At the end, all threads have the correct smem_pipe_write. - return true; + (void)params; + (void)pipeline_k_new; + (void)pipeline_v_new; + (void)smem_pipe_write; + (void)shared_storage; + (void)seqlen_info; + (void)block_coord; + (void)work_idx; + return false; } template