diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 6039827d479..3bb67b7d56f 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -7,8 +7,11 @@ #include constexpr int smem_estimate_bytes(int block_m, int block_n, int headdim, int headdim_v, int element_size) { - // Double-buffer the residency for Q/K/V and the accumulators to reflect the large SMEM footprint observed in practice. - return 2 * (block_m + block_n) * (headdim + headdim_v) * element_size; + // Double-buffer the residency for Q/K/V and the accumulators when the head/value dims are modest. + // For very large dimensions the single-stage SM120 pipeline only needs a single residency footprint, so + // avoid over-clamping by scaling the estimate down to one buffer in that regime. + int const buffering = (headdim + headdim_v >= 512) ? 1 : 2; + return buffering * (block_m + block_n) * (headdim + headdim_v) * element_size; } constexpr int clamp_block_n_for_smem(int block_m, int block_n, int headdim, int headdim_v, @@ -17,12 +20,12 @@ constexpr int clamp_block_n_for_smem(int block_m, int block_n, int headdim, int if (smem_usage <= smem_limit) { return block_n; } - // Keep the tile width aligned to 8 to match the granularity of our block shapes while allowing tight caps. + // Keep the tile width aligned to 16 to satisfy GMMA tile constraints while allowing tight caps. int const denom = 2 * element_size * (headdim + headdim_v); int max_block_n = denom > 0 ? smem_limit / denom - block_m : block_n; - if (max_block_n < 8) { max_block_n = 8; } - max_block_n = (max_block_n / 8) * 8; - return max_block_n > 0 ? max_block_n : 8; + if (max_block_n < 16) { max_block_n = 16; } + max_block_n = (max_block_n / 16) * 16; + return max_block_n > 0 ? max_block_n : 16; } constexpr std::tuple enforce_smem_limit(int block_m, int block_n, int headdim, int headdim_v, @@ -37,9 +40,9 @@ constexpr std::tuple enforce_smem_limit(int block_m, int block_n, int if (smem_usage > smem_limit) { int const denom = 2 * element_size * (headdim + headdim_v); int max_block_m = denom > 0 ? smem_limit / denom - adjusted_block_n : block_m; - if (max_block_m < 8) { max_block_m = 8; } - max_block_m = (max_block_m / 8) * 8; - block_m = max_block_m > 0 ? max_block_m : 8; + if (max_block_m < 16) { max_block_m = 16; } + max_block_m = (max_block_m / 16) * 16; + block_m = max_block_m > 0 ? max_block_m : 16; adjusted_block_n = clamp_block_n_for_smem(block_m, adjusted_block_n, headdim, headdim_v, element_size, smem_limit); } return {block_m, adjusted_block_n}; diff --git a/tests/hopper/test_tile_size_shared_memory.py b/tests/hopper/test_tile_size_shared_memory.py index 1ddddd0524b..49e3fe24135 100644 --- a/tests/hopper/test_tile_size_shared_memory.py +++ b/tests/hopper/test_tile_size_shared_memory.py @@ -52,8 +52,9 @@ def _load_bridge() -> ctypes.CDLL: def estimate_smem_bytes(block_m: int, block_n: int, headdim: int, headdim_v: int, element_size: int) -> int: - # Mirror the double-buffer estimate used in hopper/tile_size.h. - return 2 * (block_m + block_n) * (headdim + headdim_v) * element_size + # Mirror the buffering-aware estimate used in hopper/tile_size.h. + buffering = 1 if headdim + headdim_v >= 512 else 2 + return buffering * (block_m + block_n) * (headdim + headdim_v) * element_size def test_tile_sizes_stay_within_blackwell_smem_budget():