diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 1e84e893fb1..7079110ed7f 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -58,7 +58,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap); static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap); static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS); - static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS); + // Consumer Blackwell parts expose ~100KB of shared memory, so reduce the forward pipeline depth + // to keep the shared memory footprint under the device limit while retaining the SM90 depth on + // H100-class GPUs. + static constexpr int kStages = Arch >= 120 ? 1 : (Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS)); static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS); using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>; @@ -190,7 +193,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params); dim3 block_dims = AttnKernel::get_block_shape(); - int smem_size = AttnKernel::SharedStorageSize; + static constexpr int kSmemSize = AttnKernel::SharedStorageSize; + static_assert(Arch < 120 || kSmemSize <= 101376, + "SM120 forward kernel requires more shared memory than the consumer budget; reduce tile sizes."); + int smem_size = kSmemSize; int max_threads_per_block = 0; int smem_limit_optin = 0; int smem_limit = 0; diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 8353542c477..6039827d479 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -6,37 +6,97 @@ #include <tuple> +constexpr int smem_estimate_bytes(int block_m, int block_n, int headdim, int headdim_v, int element_size) { + // Double-buffer the residency for Q/K/V and the accumulators to reflect 
the large SMEM footprint observed in practice. + return 2 * (block_m + block_n) * (headdim + headdim_v) * element_size; +} + +constexpr int clamp_block_n_for_smem(int block_m, int block_n, int headdim, int headdim_v, + int element_size, int smem_limit) { + int const smem_usage = smem_estimate_bytes(block_m, block_n, headdim, headdim_v, element_size); + if (smem_usage <= smem_limit) { + return block_n; + } + // Keep the tile width aligned to 8 to match the granularity of our block shapes while allowing tight caps. + int const denom = 2 * element_size * (headdim + headdim_v); + int max_block_n = denom > 0 ? smem_limit / denom - block_m : block_n; + if (max_block_n < 8) { max_block_n = 8; } + max_block_n = (max_block_n / 8) * 8; + return max_block_n > 0 ? max_block_n : 8; +} + +constexpr std::tuple<int, int> enforce_smem_limit(int block_m, int block_n, int headdim, int headdim_v, + int element_size, int smem_limit) { + int adjusted_block_n = clamp_block_n_for_smem(block_m, block_n, headdim, headdim_v, element_size, smem_limit); + int smem_usage = smem_estimate_bytes(block_m, adjusted_block_n, headdim, headdim_v, element_size); + if (smem_usage > smem_limit && block_m > 64) { + block_m = 64; + adjusted_block_n = clamp_block_n_for_smem(block_m, adjusted_block_n, headdim, headdim_v, element_size, smem_limit); + smem_usage = smem_estimate_bytes(block_m, adjusted_block_n, headdim, headdim_v, element_size); + } + if (smem_usage > smem_limit) { + int const denom = 2 * element_size * (headdim + headdim_v); + int max_block_m = denom > 0 ? smem_limit / denom - adjusted_block_n : block_m; + if (max_block_m < 8) { max_block_m = 8; } + max_block_m = (max_block_m / 8) * 8; + block_m = max_block_m > 0 ? 
max_block_m : 8; + adjusted_block_n = clamp_block_n_for_smem(block_m, adjusted_block_n, headdim, headdim_v, element_size, smem_limit); + } + return {block_m, adjusted_block_n}; +} + // Return {kBlockM, kBlockN, MmaPV_is_RS, IntraWGOverlap} constexpr std::tuple<int, int, bool, bool> tile_size_fwd_sm90( int headdim, int headdim_v, bool is_causal, bool is_local, int element_size=2, bool v_colmajor=false, bool paged_kv_non_TMA=false, bool softcap=false) { + constexpr int kSm120ConsumerSmemLimit = 101376; if (element_size == 2) { if (headdim <= 64) { // return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, same_hdim}; // With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, idk why // https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131 if (headdim_v == 512) { - return {64, 64, false, false}; + // Keep the tile narrow to avoid blowing past the consumer shared-memory budget when values are very wide. + auto const [block_m, block_n] = enforce_smem_limit(64, 64, headdim, headdim_v, element_size, kSm120ConsumerSmemLimit); + return {block_m, block_n, false, false}; } else if (headdim_v == 256) { - return {128, 96, true, false}; + auto const [block_m, block_n] = enforce_smem_limit(64, 80, headdim, headdim_v, element_size, kSm120ConsumerSmemLimit); + return {block_m, block_n, true, true}; } else { // Switch to tile size 192 x 192 for now bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA; - return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true}; + auto const [block_m, block_n] = enforce_smem_limit(192, use_blockN_128 ? 128 : 192, headdim, headdim_v, element_size, kSm120ConsumerSmemLimit); + return {block_m, block_n, use_blockN_128, true}; } // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen // return {192, is_causal || is_local ? 192 : 176, true, false}; } else if (headdim <= 96) { - return {192, is_local || paged_kv_non_TMA ? 
128 : 144, false, true}; + // Large value dimensions inflate smem usage even at modest head sizes, so bias toward smaller tiles for dv >= 256. + int const block_n = headdim_v >= 256 ? 96 : (is_local || paged_kv_non_TMA ? 128 : 144); + auto const [block_m, block_n_capped] = enforce_smem_limit(block_n == 96 ? 128 : 192, block_n, headdim, headdim_v, element_size, kSm120ConsumerSmemLimit); + return {block_m, block_n_capped, false, true}; } else if (headdim <= 128) { - bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA; - return {128, use_blockN_128 ? 128 : 176, true, true}; + // Shared memory on consumer parts tops out at ~100KB, so prefer a BlockM=64 path that stays under that limit while + // keeping BlockN as large as possible for throughput. + int const block_n = paged_kv_non_TMA || is_local ? 80 : (headdim_v <= 128 ? 96 : 80); + auto const [block_m, block_n_capped] = enforce_smem_limit(64, block_n, headdim, headdim_v, element_size, kSm120ConsumerSmemLimit); + return {block_m, block_n_capped, true, true}; // {128, 192, true, false} and {192, 128, false, true} are quite good too // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS } else if (headdim <= 192) { - return {128, paged_kv_non_TMA || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem + // The 128x128 / 128x112 tiles exceed the ~100KB shared memory limit of consumer GPUs (for example, when running on + // devices without the larger H100 shared memory carve‑out). Use smaller tiles for all value dims to guarantee we + // stay below the per-block cap across head dimensions up to 192. + int const block_n = paged_kv_non_TMA || is_local ? 64 : (headdim <= 160 ? 80 : 64); + auto const [block_m, block_n_capped] = enforce_smem_limit(64, block_n, headdim, headdim_v, element_size, kSm120ConsumerSmemLimit); + return {block_m, block_n_capped, true, true}; } else { - return {128, is_local ? 
64 : 80, true, true}; // 128 x 80 hits the limit of smem + // For head dims above 192 the shared-memory footprint grows quickly with BlockM, so stick to 64xN tiles even though + // they are smaller than the H100-optimized 128xN shapes. Favor narrower BlockN when value dims are large to stay + // under the ~100KB cap on consumer GPUs. + int const block_n = paged_kv_non_TMA || is_local ? 48 : (headdim <= 256 ? 64 : 48); + auto const [block_m, block_n_capped] = enforce_smem_limit(64, block_n, headdim, headdim_v, element_size, kSm120ConsumerSmemLimit); + return {block_m, block_n_capped, true, true}; } } else { if (headdim <= 64) { diff --git a/tests/hopper/test_tile_size_shared_memory.py b/tests/hopper/test_tile_size_shared_memory.py new file mode 100644 index 00000000000..1ddddd0524b --- /dev/null +++ b/tests/hopper/test_tile_size_shared_memory.py @@ -0,0 +1,103 @@ +import ctypes +import os +import subprocess +import sys +import tempfile +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[2] +SMEM_LIMIT_BYTES = 101_376 + + +def _build_tile_size_bridge(tmpdir: Path) -> Path: + src = tmpdir / "tile_size_bridge.cpp" + lib = tmpdir / "libtile_size_bridge.so" + src.write_text( + r''' +#include <tuple> +#include "hopper/tile_size.h" + +extern "C" void tile_size_fwd_sm90_bridge( + int headdim, int headdim_v, bool is_causal, bool is_local, int element_size, + bool v_colmajor, bool paged_kv_non_TMA, bool softcap, + int* block_m, int* block_n) { + auto result = tile_size_fwd_sm90(headdim, headdim_v, is_causal, is_local, + element_size, v_colmajor, paged_kv_non_TMA, softcap); + *block_m = std::get<0>(result); + *block_n = std::get<1>(result); +} +''', + encoding="utf-8", + ) + compile_cmd = [ + "g++", + "-std=c++20", + "-fPIC", + "-shared", + "-O2", + f"-I{REPO_ROOT}", + str(src), + "-o", + str(lib), + ] + subprocess.run(compile_cmd, check=True) + return lib + + +def _load_bridge() -> ctypes.CDLL: + with tempfile.TemporaryDirectory() as td: + lib = 
_build_tile_size_bridge(Path(td)) + return ctypes.CDLL(str(lib)) + + +def estimate_smem_bytes(block_m: int, block_n: int, headdim: int, headdim_v: int, element_size: int) -> int: + # Mirror the double-buffer estimate used in hopper/tile_size.h. + return 2 * (block_m + block_n) * (headdim + headdim_v) * element_size + + +def test_tile_sizes_stay_within_blackwell_smem_budget(): + bridge = _load_bridge() + bridge.tile_size_fwd_sm90_bridge.argtypes = [ + ctypes.c_int, + ctypes.c_int, + ctypes.c_bool, + ctypes.c_bool, + ctypes.c_int, + ctypes.c_bool, + ctypes.c_bool, + ctypes.c_bool, + ctypes.POINTER(ctypes.c_int), + ctypes.POINTER(ctypes.c_int), + ] + + head_dims = (64, 96, 128, 160, 192, 256, 320) + value_dims = (64, 96, 128, 160, 192, 256, 512) + bools = (False, True) + + for headdim in head_dims: + for headdim_v in value_dims: + for is_causal in bools: + for is_local in bools: + if is_causal and is_local: + continue # invalid combination + for paged_kv_non_tma in bools: + block_m = ctypes.c_int() + block_n = ctypes.c_int() + bridge.tile_size_fwd_sm90_bridge( + headdim, + headdim_v, + is_causal, + is_local, + 2, # fp16/bf16 element size + False, # v_colmajor + paged_kv_non_tma, + False, # softcap + ctypes.byref(block_m), + ctypes.byref(block_n), + ) + smem_bytes = estimate_smem_bytes(block_m.value, block_n.value, headdim, headdim_v, 2) + assert smem_bytes <= SMEM_LIMIT_BYTES, ( + f"SMEM overrun for d={headdim}, dv={headdim_v}, causal={is_causal}, " + f"local={is_local}, paged={paged_kv_non_tma}: {smem_bytes}B" + )