133 changes: 78 additions & 55 deletions hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp

@@ -646,6 +646,83 @@ struct CollectiveMainloopFwdSm90 {
}
}

template <typename SharedStorage, typename TensorQ, typename TensorQv>
CUTLASS_DEVICE void load_q_impl(
std::true_type /*UseTmaQ*/, Params const& params, SharedStorage &shared_storage,
SeqlenInfo_t const& seqlen_info, cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,
int thread_idx, bool is_varlen_q, int warp_idx_in_warpgroup, TensorQ sQ, TensorQv sQv) {

int const m_block = get<0>(block_coord);
int const bidh = get<1>(block_coord);
int const bidb = get<2>(block_coord);

Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);
Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
auto block_tma_Q = params.tma_load_Q.get_slice(_0{});
Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA)
Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); // (TMA)
if (thread_idx == 0) { prefetch(block_tma_Q, tQgQ); }

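// Immediately-invoked constexpr lambda: the two branches return tuples of
// different types, which a plain if/else could not. When HasQv is false the
// nullptr placeholders are never read; only the taken branch is instantiated.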
auto [tQvgQv, tQvsQv] = [&] {
if constexpr (HasQv) {
auto shape_Qv = make_shape(get<0>(params.shape_Q), params.headdim_v, get<2>(params.shape_Q), get<3>(params.shape_Q));
Tensor mQv = params.tma_load_Qv.get_tma_tensor(shape_Qv)(_, _, bidh, !is_varlen_q ? bidb : 0);
Tensor gQv = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQv), select<0, 2>(TileShape_MNK_QV{}), make_coord(m_block, _0{})); // (M, Kv)
auto block_tma_Qv = params.tma_load_Qv.get_slice(_0{});
Tensor tQvgQv = group_modes<0, 3>(block_tma_Qv.partition_S(gQv)); // (TMA)
Tensor tQvsQv = group_modes<0, 3>(block_tma_Qv.partition_D(sQv)); // (TMA)
return cute::make_tuple(tQvgQv, tQvsQv);
} else {
return cute::make_tuple(nullptr, nullptr);
}
}();

// Wait for the MMA warpgroups to signal that smem_q is ready
if (SingleProducerWarp || warp_idx_in_warpgroup == 0) {
cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(FwdNamedBarriers::QueryEmpty) /*id*/);
}

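// TMA copies are issued once per CTA, so a single elected thread performs them;
// arrive_and_expect_tx registers how many bytes the transaction barrier should
// count before it completes.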
if ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()) {
shared_storage.pipelines.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
copy(params.tma_load_Q.with(reinterpret_cast<typename cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST),
tQgQ, tQsQ);
if constexpr (HasQv) {
shared_storage.pipelines.barrier_Qv.arrive_and_expect_tx(TmaTransactionBytesQv);
copy(params.tma_load_Qv.with(reinterpret_cast<typename cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.pipelines.barrier_Qv), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST),
tQvgQv, tQvsQv);
}
}
}

template <typename SharedStorage, typename TensorQ, typename TensorQv>
CUTLASS_DEVICE void load_q_impl(
std::false_type /*UseTmaQ*/, Params const& params, SharedStorage &shared_storage,
SeqlenInfo_t const& seqlen_info, cute::tuple<int32_t, int32_t, int32_t, int32_t> block_coord,
int thread_idx, bool is_varlen_q, int /*warp_idx_in_warpgroup*/, TensorQ sQ, TensorQv sQv) {

int const m_block = get<0>(block_coord);
int const bidh = get<1>(block_coord);
int const bidb = get<2>(block_coord);

cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::QueryEmpty) /*id*/);
Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0);
Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ);
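// PackGQA: query heads that share a KV head are packed along the row (M)
// dimension, and qhead_per_khead_divmod lets load_Q recover (head, row) from
// the packed index.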
using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element>;
PackGQAt::load_Q(mQ, sQ_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block);
auto &barrier_Q = shared_storage.pipelines.barrier_Q;
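// cpasync_barrier_arrive ties completion of this thread's outstanding cp.async
// copies to the mbarrier; the explicit arrive() then records the thread's own
// arrival.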
cutlass::arch::cpasync_barrier_arrive(reinterpret_cast<uint64_t*>(&barrier_Q));
barrier_Q.arrive();
if constexpr (HasQv) {
Tensor mQv = make_tensor(make_gmem_ptr(params.ptr_Qv + seqlen_info.offset_q * get<0>(params.stride_Qv)), params.shape_Qv_packed, params.stride_Qv_packed)(_, _, bidh, !is_varlen_q ? bidb : 0);
Tensor sQv_pi = cute::as_position_independent_swizzle_tensor(sQv);
using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK_QV{}), get<2>(TileShape_MNK_QV{}), NumProducerThreads, Element>;
PackGQAt::load_Q(mQv, sQv_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block);
auto &barrier_Qv = shared_storage.pipelines.barrier_Qv;
cutlass::arch::cpasync_barrier_arrive(reinterpret_cast<uint64_t*>(&barrier_Qv));
barrier_Qv.arrive();
}
}
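
The two overloads above differ only in their leading tag parameter
(std::true_type vs std::false_type), so the caller can pick the TMA or the
cp.async path at compile time with no runtime branch. A minimal sketch of the
idiom, with hypothetical names that are not part of this PR:

#include <type_traits>

struct Loader {
    void load_impl(std::true_type /*UseTma*/)  { /* bulk-copy (TMA) path */ }
    void load_impl(std::false_type /*UseTma*/) { /* cp.async fallback path */ }

    template <bool UseTma>
    void load() {
        // std::integral_constant<bool, UseTma>{} is std::true_type or
        // std::false_type, so overload resolution selects the path statically.
        load_impl(std::integral_constant<bool, UseTma>{});
    }
};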

template <typename SchedulerPrefetch, typename SharedStorage>
CUTLASS_DEVICE void
load(Params const& params,
@@ -850,61 +927,7 @@ struct CollectiveMainloopFwdSm90 {
// if (thread_idx == 0) { printf("Producer: main load, after load K, index = %d\n", smem_pipe_write.index());}
}

if constexpr (Use_TMA_Q) {
Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0);
Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K)
auto block_tma_Q = params.tma_load_Q.get_slice(_0{});
Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA)
Tensor tQsQ = group_modes<0, 3>(block_tma_Q.partition_D(sQ)); // (TMA)
if (thread_idx == 0) { prefetch(block_tma_Q, tQgQ); }
auto [tQvgQv, tQvsQv] = [&] {
if constexpr (HasQv) {
auto shape_Qv = make_shape(get<0>(params.shape_Q), params.headdim_v, get<2>(params.shape_Q), get<3>(params.shape_Q));
Tensor mQv = params.tma_load_Qv.get_tma_tensor(shape_Qv)(_, _, bidh, !is_varlen_q ? bidb : 0);
Tensor gQv = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQv), select<0, 2>(TileShape_MNK_QV{}), make_coord(m_block, _0{})); // (M, Kv)
auto block_tma_Qv = params.tma_load_Qv.get_slice(_0{});
Tensor tQvgQv = group_modes<0, 3>(block_tma_Qv.partition_S(gQv)); // (TMA)
Tensor tQvsQv = group_modes<0, 3>(block_tma_Qv.partition_D(sQv)); // (TMA)
return cute::make_tuple(tQvgQv, tQvsQv);
} else {
return cute::make_tuple(nullptr, nullptr);
}
}();

// Wait for the MMA warpgroups to signal that smem_q is ready
if (SingleProducerWarp || warp_idx_in_warpgroup == 0) {
cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(FwdNamedBarriers::QueryEmpty) /*id*/);
}

if ((SingleProducerWarp || warp_idx_in_warpgroup == 0) && cute::elect_one_sync()) {
shared_storage.pipelines.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
copy(params.tma_load_Q.with(reinterpret_cast<typename cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.pipelines.barrier_Q), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST),
tQgQ, tQsQ);
if constexpr (HasQv) {
shared_storage.pipelines.barrier_Qv.arrive_and_expect_tx(TmaTransactionBytesQv);
copy(params.tma_load_Qv.with(reinterpret_cast<typename cutlass::arch::ClusterTransactionBarrier::ValueType&>(shared_storage.pipelines.barrier_Qv), 0 /*mcast_mask*/, !Split ? TMA::CacheHintSm90::EVICT_FIRST : TMA::CacheHintSm90::EVICT_LAST),
tQvgQv, tQvsQv);
}
}
} else { // Load Q with cp.async
cutlass::arch::NamedBarrier::sync(NumMmaThreadsQK + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::QueryEmpty) /*id*/);
Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0);
Tensor sQ_pi = cute::as_position_independent_swizzle_tensor(sQ);
using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK{}), get<2>(TileShape_MNK{}), NumProducerThreads, Element>;
PackGQAt::load_Q(mQ, sQ_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block);
auto &barrier_Q = shared_storage.pipelines.barrier_Q;
cutlass::arch::cpasync_barrier_arrive(reinterpret_cast<uint64_t*>(&barrier_Q));
barrier_Q.arrive();
if constexpr (HasQv) {
Tensor mQv = make_tensor(make_gmem_ptr(params.ptr_Qv + seqlen_info.offset_q * get<0>(params.stride_Qv)), params.shape_Qv_packed, params.stride_Qv_packed)(_, _, bidh, !is_varlen_q ? bidb : 0);
Tensor sQv_pi = cute::as_position_independent_swizzle_tensor(sQv);
using PackGQAt = flash::PackGQAManager<get<0>(TileShape_MNK_QV{}), get<2>(TileShape_MNK_QV{}), NumProducerThreads, Element>;
PackGQAt::load_Q(mQv, sQv_pi, params.qhead_per_khead_divmod, thread_idx, seqlen_info.seqlen_q, m_block);
auto &barrier_Qv = shared_storage.pipelines.barrier_Qv;
cutlass::arch::cpasync_barrier_arrive(reinterpret_cast<uint64_t*>(&barrier_Qv));
barrier_Qv.arrive();
}
}
load_q_impl(std::integral_constant<bool, Use_TMA_Q>{}, params, shared_storage, seqlen_info, block_coord, thread_idx, is_varlen_q, warp_idx_in_warpgroup, sQ, sQv);
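// std::integral_constant<bool, Use_TMA_Q> is std::true_type when Use_TMA_Q is
// set and std::false_type otherwise, so the call above statically dispatches to
// the matching load_q_impl overload.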

// Wait for the MMA WGs to signal that smem_v is ready and V can be copied from gmem
// Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finishing the