Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions ggml/src/ggml-cuda/fattn-mma-f16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
constexpr int ncols = ncols1 * ncols2;
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = get_cols_per_thread();
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
Expand Down Expand Up @@ -510,7 +510,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
} else {
static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
#pragma unroll
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
Expand All @@ -522,14 +521,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
T_A_KQ K_A;
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);

// Wide version of KQ_C is column-major
if constexpr (cols_per_warp == 8) {
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
} else {
// Wide version of KQ_C is column-major
#if defined(AMD_WMMA_AVAILABLE)
// RDNA matrix C is column-major.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
// RDNA matrix C is column-major.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
#else
// swap A and B for CUDA.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
// swap A and B for CUDA.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
#endif // defined(AMD_WMMA_AVAILABLE)
}
}
}
}
Expand Down Expand Up @@ -953,7 +956,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(

constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = get_cols_per_thread();
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
Expand Down Expand Up @@ -1484,6 +1487,13 @@ static __global__ void flash_attn_ext_f16(
NO_DEVICE_CODE;
return;
}
#ifdef VOLTA_MMA_AVAILABLE
if (ncols1*ncols2 < 32) {
NO_DEVICE_CODE;
return;
}
#endif // VOLTA_MMA_AVAILABLE

#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
if (ncols1*ncols2 > 32) {
NO_DEVICE_CODE;
Expand Down Expand Up @@ -1728,3 +1738,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);

// For GLM 4.7 Flash
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
12 changes: 12 additions & 0 deletions ggml/src/ggml-cuda/fattn-tile.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)

return 0;
Expand Down Expand Up @@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)

return 0;
Expand Down Expand Up @@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)

Expand Down Expand Up @@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)

GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)

Expand Down Expand Up @@ -1187,6 +1195,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 4 == 0) {
launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
return;
}
}

if constexpr (DV <= 256) {
Expand Down
10 changes: 7 additions & 3 deletions ggml/src/ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg

GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
GGML_ASSERT(gqa_ratio % 16 == 0);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
GGML_ASSERT(gqa_ratio % 4 == 0);
if (gqa_ratio % 16 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
}
} break;
default:
GGML_ABORT("fatal error");
Expand Down Expand Up @@ -262,7 +266,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
if (V->ne[0] != 512) {
return BEST_FATTN_KERNEL_NONE;
}
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
return BEST_FATTN_KERNEL_NONE;
}
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/template-instances/generate_cu_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_short_name(long_quant_name):
continue
if head_size_kq != 576 and ncols2 == 16:
continue
if head_size_kq == 576 and ncols2 != 16:
if head_size_kq == 576 and ncols2 not in (4, 16):
continue
head_size_v = head_size_kq if head_size_kq != 576 else 512
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
Expand Down
Loading