diff --git a/hopper/tile_size.h b/hopper/tile_size.h index 3bb67b7d56f..220ed4a3929 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -8,9 +8,9 @@ constexpr int smem_estimate_bytes(int block_m, int block_n, int headdim, int headdim_v, int element_size) { // Double-buffer the residency for Q/K/V and the accumulators when the head/value dims are modest. - // For very large dimensions the single-stage SM120 pipeline only needs a single residency footprint, so - // avoid over-clamping by scaling the estimate down to one buffer in that regime. - int const buffering = (headdim + headdim_v >= 512) ? 1 : 2; + // Value dimensions 256+ already drive the shared-memory footprint high, so treat them like the large + // combined-head/value case and drop to a single buffer in that regime to avoid over-clamping. + int const buffering = (headdim_v >= 256 || headdim + headdim_v >= 512) ? 1 : 2; return buffering * (block_m + block_n) * (headdim + headdim_v) * element_size; } diff --git a/tests/hopper/test_tile_size_shared_memory.py b/tests/hopper/test_tile_size_shared_memory.py index 49e3fe24135..7cabd0856f9 100644 --- a/tests/hopper/test_tile_size_shared_memory.py +++ b/tests/hopper/test_tile_size_shared_memory.py @@ -53,7 +53,7 @@ def _load_bridge() -> ctypes.CDLL: def estimate_smem_bytes(block_m: int, block_n: int, headdim: int, headdim_v: int, element_size: int) -> int: # Mirror the buffering-aware estimate used in hopper/tile_size.h. - buffering = 1 if headdim + headdim_v >= 512 else 2 + buffering = 1 if (headdim_v >= 256 or headdim + headdim_v >= 512) else 2 return buffering * (block_m + block_n) * (headdim + headdim_v) * element_size