From 9078bb4a603eb071af4c55cd1c9533100d5115ba Mon Sep 17 00:00:00 2001 From: ivandobskygithub <84400859+ivandobskygithub@users.noreply.github.com> Date: Tue, 25 Nov 2025 13:16:41 +0000 Subject: [PATCH] Guard non-TMA AppendKV paths --- hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp | 354 ++++++++++++++--------- 1 file changed, 213 insertions(+), 141 deletions(-) diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 536ff855fd4..68129978dcd 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -245,34 +245,64 @@ struct CollectiveMainloopFwdSm90 { using StrideRotary = cute::Stride; using StrideDescale = cute::Stride; - using TMA_Q = decltype(make_tma_copy_A_sm90( - GmemTiledCopyQ{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), - SmemLayoutQ{}, - TileShape_MNK{}, - ClusterShape{})); - - using TMA_K = decltype(make_tma_copy_B_sm90( - GmemTiledCopyKV{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), - take<0, 2>(SmemLayoutK{}), - TileShape_MNK{}, - ClusterShape{})); // mcast along M mode for this N load, if any - - using TMA_V = decltype(make_tma_copy( - GmemTiledCopyKV{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, select<1, 0, 2, 3>(StrideV{})), - take<0, 2>(SmemLayoutVt{}), - select<1, 2>(TileShape_MNK_PV{}), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any - - using TMA_Qv_ = decltype(make_tma_copy_A_sm90( - GmemTiledCopyQ{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), - SmemLayoutQv{}, - TileShape_MNK_QV{}, - ClusterShape{})); - using TMA_Qv = std::conditional_t; + struct DummyTmaDescriptor {}; + + static constexpr auto make_tma_q_type() { + if constexpr (Use_TMA_Q) { + return make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + SmemLayoutQ{}, + TileShape_MNK{}, + ClusterShape{}); + } else { + return DummyTmaDescriptor{}; + } + } + + static constexpr auto make_tma_k_type() { + if constexpr (Use_TMA_KV) { + return make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + take<0, 2>(SmemLayoutK{}), + TileShape_MNK{}, + ClusterShape{}); + } else { + return DummyTmaDescriptor{}; + } + } + + static constexpr auto make_tma_v_type() { + if constexpr (Use_TMA_KV) { + return make_tma_copy( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, select<1, 0, 2, 3>(StrideV{})), + take<0, 2>(SmemLayoutVt{}), + select<1, 2>(TileShape_MNK_PV{}), + size<0>(ClusterShape{})); + } else { + return DummyTmaDescriptor{}; + } + } + + static constexpr auto make_tma_qv_type() { + if constexpr (HasQv) { + return make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + SmemLayoutQv{}, + TileShape_MNK_QV{}, + ClusterShape{}); + } else { + return DummyTmaDescriptor{}; + } + } + + using TMA_Q = decltype(make_tma_q_type()); + using TMA_K = decltype(make_tma_k_type()); + using TMA_V = decltype(make_tma_v_type()); + using TMA_Qv = decltype(make_tma_qv_type()); // Set the bytes transferred in this TMA transaction (may involve multiple issues) static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); @@ -458,44 +488,74 @@ struct CollectiveMainloopFwdSm90 { static Params to_underlying_arguments(Arguments const& args) { Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q); - TMA_Q tma_load_Q = make_tma_copy_A_sm90( - GmemTiledCopyQ{}, - mQ, - SmemLayoutQ{}, - TileShape_MNK{}, - ClusterShape{}); // no mcast for Q + TMA_Q tma_load_Q = [&] { + if constexpr (Use_TMA_Q) { + return make_tma_copy_A_sm90( + GmemTiledCopyQ{}, + mQ, + SmemLayoutQ{}, + TileShape_MNK{}, + ClusterShape{}); // no mcast for Q + } else { + return TMA_Q{}; + } + }(); Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K); - TMA_K tma_load_K = make_tma_copy_B_sm90( - GmemTiledCopyKV{}, - mK, - take<0, 2>(SmemLayoutK{}), - TileShape_MNK{}, - ClusterShape{}); // mcast along M mode for this N load, if any + TMA_K tma_load_K = [&] { + if constexpr (Use_TMA_KV) { + return make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + mK, + take<0, 2>(SmemLayoutK{}), + TileShape_MNK{}, + ClusterShape{}); // mcast along M mode for this N load, if any + } else { + return TMA_K{}; + } + }(); Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), make_shape(args.headdim_v, get<0>(args.shape_K), get<2>(args.shape_K), get<3>(args.shape_K)), select<1, 0, 2, 3>(args.stride_V)); - TMA_V tma_load_V = make_tma_copy( - GmemTiledCopyKV{}, - mV, - take<0, 2>(SmemLayoutVt{}), - select<1, 2>(TileShape_MNK_PV{}), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + TMA_V tma_load_V = [&] { + if constexpr (Use_TMA_KV) { + return make_tma_copy( + GmemTiledCopyKV{}, + mV, + take<0, 2>(SmemLayoutVt{}), + select<1, 2>(TileShape_MNK_PV{}), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + } else { + return TMA_V{}; + } + }(); Tensor mKnew = make_tensor(make_gmem_ptr(args.ptr_K_new), args.shape_K_new, args.stride_K_new); - TMA_K tma_load_K_new = make_tma_copy_B_sm90( - GmemTiledCopyKV{}, - cute::conditional_return(mKnew, mK), - take<0, 2>(SmemLayoutK{}), - TileShape_MNK{}, - ClusterShape{}); // mcast along M mode for this N load, if any + TMA_K tma_load_K_new = [&] { + if constexpr (Use_TMA_KV) { + return make_tma_copy_B_sm90( + GmemTiledCopyKV{}, + cute::conditional_return(mKnew, mK), + take<0, 2>(SmemLayoutK{}), + TileShape_MNK{}, + ClusterShape{}); // mcast along M mode for this N load, if any + } else { + return TMA_K{}; + } + }(); Tensor mVnew = make_tensor(make_gmem_ptr(args.ptr_V_new), make_shape(args.headdim_v, get<0>(args.shape_K_new), get<2>(args.shape_K_new), get<3>(args.shape_K_new)), select<1, 0, 2, 3>(args.stride_V_new)); - TMA_V tma_load_V_new = make_tma_copy( - GmemTiledCopyKV{}, - cute::conditional_return(mVnew, mV), - take<0, 2>(SmemLayoutVt{}), - select<1, 2>(TileShape_MNK_PV{}), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + TMA_V tma_load_V_new = [&] { + if constexpr (Use_TMA_KV) { + return make_tma_copy( + GmemTiledCopyKV{}, + cute::conditional_return(mVnew, mV), + take<0, 2>(SmemLayoutVt{}), + select<1, 2>(TileShape_MNK_PV{}), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + } else { + return TMA_V{}; + } + }(); auto shape_Qv = make_shape(get<0>(args.shape_Q), args.headdim_v, get<2>(args.shape_Q), get<3>(args.shape_Q)); Tensor mQv = make_tensor(make_gmem_ptr(args.ptr_Qv), shape_Qv, args.stride_Qv); TMA_Qv tma_load_Qv = [&] { @@ -507,7 +567,7 @@ struct CollectiveMainloopFwdSm90 { TileShape_MNK_QV{}, ClusterShape{}); // no mcast for Qv } else { - return nullptr; + return TMA_Qv{}; } }(); // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) @@ -580,7 +640,7 @@ struct CollectiveMainloopFwdSm90 { cute::prefetch_tma_descriptor(params.tma_load_K.get_tma_descriptor()); cute::prefetch_tma_descriptor(params.tma_load_V.get_tma_descriptor()); } - if constexpr (AppendKV) { + if constexpr (AppendKV && Use_TMA_KV) { cute::prefetch_tma_descriptor(params.tma_load_K_new.get_tma_descriptor()); cute::prefetch_tma_descriptor(params.tma_load_V_new.get_tma_descriptor()); } @@ -1453,97 +1513,109 @@ struct CollectiveMainloopFwdSm90 { cute::tuple block_coord, int const work_idx ) { + if constexpr (Use_TMA_KV) { - auto [m_block, bidh, bidb, split_idx] = block_coord; - auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( - seqlen_info, m_block, bidb, split_idx, params.num_splits, - params.window_size_left, params.window_size_right, params.attention_chunk_divmod, - params.qhead_per_khead_divmod); + auto [m_block, bidh, bidb, split_idx] = block_coord; + auto [n_block_new_min, n_block_new_max] = BlockMN_t::get_n_block_k_new_min_max( + seqlen_info, m_block, bidb, split_idx, params.num_splits, + params.window_size_left, params.window_size_right, params.attention_chunk_divmod, + params.qhead_per_khead_divmod); - if (n_block_new_max <= n_block_new_min) { return false; } + if (n_block_new_max <= n_block_new_min) { return false; } - Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); - Tensor sVt = [&] { - if constexpr (!Transpose_V) { - return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{}); - } else { - return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVt{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_k.data()), SmemLayoutK{}); + Tensor sVt = [&] { + if constexpr (!Transpose_V) { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_v.data()), SmemLayoutVt{}); + } else { + return make_tensor(make_smem_ptr(shared_storage.tensors.mainloop.smem_vt.data()), SmemLayoutVt{}); + } + }(); + + // int const thread_idx = threadIdx.x % NumProducerThreads; + int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; + + // Prepare the TMA loads + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new; + Tensor mKnew_TMA = params.tma_load_K_new.get_tma_tensor(params.shape_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); + auto shape_Vnew = make_shape(params.headdim_v, get<0>(params.shape_K_new), get<2>(params.shape_K_new), get<3>(params.shape_K_new)); + Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(shape_Vnew)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); + + Tensor gKnew_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) + + auto block_tma_K_new = params.tma_load_K_new.get_slice(cluster_local_block_id.x); + Tensor tKgKnew_TMA = group_modes<0, 3>(block_tma_K_new.partition_S(gKnew_TMA)); // (TMA, k) + Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K_new.partition_D(sK)); // (TMA, PIPE) + auto block_tma_V_new = params.tma_load_V_new.get_slice(cluster_local_block_id.x); + Tensor tVgVnewt_TMA = group_modes<0, 3>(block_tma_V_new.partition_S(gVnewt_TMA)); // (TMA, k) + Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V_new.partition_D(sVt)); // (TMA, PIPE) + + uint16_t mcast_mask_kv = 0; + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{})); + } } - }(); - - // int const thread_idx = threadIdx.x % NumProducerThreads; - int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; - // Prepare the TMA loads - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - - bool const is_varlen_k_new = Varlen && params.cu_seqlens_k_new; - Tensor mKnew_TMA = params.tma_load_K_new.get_tma_tensor(params.shape_K_new)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); - auto shape_Vnew = make_shape(params.headdim_v, get<0>(params.shape_K_new), get<2>(params.shape_K_new), get<3>(params.shape_K_new)); - Tensor mVnewt_TMA = params.tma_load_V_new.get_tma_tensor(shape_Vnew)(_, _, bidh_kv, !is_varlen_k_new ? bidb : 0); - - Tensor gKnew_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k_new, _0{}), mKnew_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor gVnewt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k_new), mVnewt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _)); // (K, N, _) + auto load_K_new = [&] (int const n_block, auto const& smem_pipe_write) { + pipeline_k_new.producer_acquire(smem_pipe_write); + copy(params.tma_load_K_new.with(*pipeline_k_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST), + tKgKnew_TMA(_, n_block), tKsK_TMA(_, smem_pipe_write.index())); + }; - auto block_tma_K_new = params.tma_load_K_new.get_slice(cluster_local_block_id.x); - Tensor tKgKnew_TMA = group_modes<0, 3>(block_tma_K_new.partition_S(gKnew_TMA)); // (TMA, k) - Tensor tKsK_TMA = group_modes<0, 3>(block_tma_K_new.partition_D(sK)); // (TMA, PIPE) - auto block_tma_V_new = params.tma_load_V_new.get_slice(cluster_local_block_id.x); - Tensor tVgVnewt_TMA = group_modes<0, 3>(block_tma_V_new.partition_S(gVnewt_TMA)); // (TMA, k) - Tensor tVsVt_TMA = group_modes<0, 3>(block_tma_V_new.partition_D(sVt)); // (TMA, PIPE) + auto load_V_new = [&] (int const n_block, auto const& smem_pipe_write) { + pipeline_v_new.producer_acquire(smem_pipe_write); + copy(params.tma_load_V_new.with(*pipeline_v_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST), + tVgVnewt_TMA(_, n_block), tVsVt_TMA(_, smem_pipe_write.index())); + }; - uint16_t mcast_mask_kv = 0; - if constexpr (cute::is_same_v) { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int m = 0; m < size<0>(block_layout); ++m) { - mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{})); + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + // If this is true, we're guaranteed that only the first warp will execute this function + static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; + bool should_load_KV = (SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync(); + + int n_block = n_block_new_max - 1; + // Need to wait for barrier_O even before load_K_new since the pipelines for AppendKV + // and the main attention are not the same. We want to make sure the consumers + // have finished reading all smem_k and smem_v for the previous iteration. + shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2); + if (should_load_KV) { load_K_new(n_block, smem_pipe_write); } + // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + if (should_load_KV) { load_V_new(n_block, smem_pipe_write); } + // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + ++smem_pipe_write; + --n_block; + // if (thread_idx == 0) { printf("Producer: before for loop\n"); } + #pragma unroll 1 + for (; n_block >= n_block_new_min; --n_block) { + if (should_load_KV) { + load_K_new(n_block, smem_pipe_write); + // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + load_V_new(n_block, smem_pipe_write); + // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } + } + ++smem_pipe_write; } + // if (thread_idx == 0) { printf("Producer: after for loop\n"); } + // At the end, all threads have the correct smem_pipe_write. + return true; } - auto load_K_new = [&] (int const n_block, auto const& smem_pipe_write) { - pipeline_k_new.producer_acquire(smem_pipe_write); - copy(params.tma_load_K_new.with(*pipeline_k_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST), - tKgKnew_TMA(_, n_block), tKsK_TMA(_, smem_pipe_write.index())); - }; - - auto load_V_new = [&] (int const n_block, auto const& smem_pipe_write) { - pipeline_v_new.producer_acquire(smem_pipe_write); - copy(params.tma_load_V_new.with(*pipeline_v_new.producer_get_barrier(smem_pipe_write), mcast_mask_kv, TMA::CacheHintSm90::EVICT_FIRST), - tVgVnewt_TMA(_, n_block), tVsVt_TMA(_, smem_pipe_write.index())); - }; - - int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); - // If this is true, we're guaranteed that only the first warp will execute this function - static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp; - bool should_load_KV = (SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync(); - - int n_block = n_block_new_max - 1; - // Need to wait for barrier_O even before load_K_new since the pipelines for AppendKV - // and the main attention are not the same. We want to make sure the consumers - // have finished reading all smem_k and smem_v for the previous iteration. - shared_storage.pipelines.barrier_O.wait((work_idx + 1) % 2); - if (should_load_KV) { load_K_new(n_block, smem_pipe_write); } - // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } - if (should_load_KV) { load_V_new(n_block, smem_pipe_write); } - // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } - ++smem_pipe_write; - --n_block; - // if (thread_idx == 0) { printf("Producer: before for loop\n"); } - #pragma unroll 1 - for (; n_block >= n_block_new_min; --n_block) { - if (should_load_KV) { - load_K_new(n_block, smem_pipe_write); - // if (thread_idx == 0) { printf("Producer: Done loading K, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } - load_V_new(n_block, smem_pipe_write); - // if (thread_idx == 0) { printf("Producer: Done loading V, n_block = %d, n_block_new_min = %d\n", n_block, n_block_new_min); } - } - ++smem_pipe_write; - } - // if (thread_idx == 0) { printf("Producer: after for loop\n"); } - // At the end, all threads have the correct smem_pipe_write. - return true; + (void)params; + (void)pipeline_k_new; + (void)pipeline_v_new; + (void)smem_pipe_write; + (void)shared_storage; + (void)seqlen_info; + (void)block_coord; + (void)work_idx; + return false; } template