Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions hopper/tile_size.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
#include <tuple>

constexpr int smem_estimate_bytes(int block_m, int block_n, int headdim, int headdim_v, int element_size) {
// Double-buffer the residency for Q/K/V and the accumulators to reflect the large SMEM footprint observed in practice.
return 2 * (block_m + block_n) * (headdim + headdim_v) * element_size;
// Double-buffer the residency for Q/K/V and the accumulators when the head/value dims are modest.
// For very large dimensions the single-stage SM120 pipeline only needs a single residency footprint, so
// avoid over-clamping by scaling the estimate down to one buffer in that regime.
int const buffering = (headdim + headdim_v >= 512) ? 1 : 2;
return buffering * (block_m + block_n) * (headdim + headdim_v) * element_size;
}

constexpr int clamp_block_n_for_smem(int block_m, int block_n, int headdim, int headdim_v,
Expand All @@ -17,12 +20,12 @@ constexpr int clamp_block_n_for_smem(int block_m, int block_n, int headdim, int
if (smem_usage <= smem_limit) {
return block_n;
}
// Keep the tile width aligned to 8 to match the granularity of our block shapes while allowing tight caps.
// Keep the tile width aligned to 16 to satisfy GMMA tile constraints while allowing tight caps.
int const denom = 2 * element_size * (headdim + headdim_v);
int max_block_n = denom > 0 ? smem_limit / denom - block_m : block_n;
if (max_block_n < 8) { max_block_n = 8; }
max_block_n = (max_block_n / 8) * 8;
return max_block_n > 0 ? max_block_n : 8;
if (max_block_n < 16) { max_block_n = 16; }
max_block_n = (max_block_n / 16) * 16;
return max_block_n > 0 ? max_block_n : 16;
}

constexpr std::tuple<int, int> enforce_smem_limit(int block_m, int block_n, int headdim, int headdim_v,
Expand All @@ -37,9 +40,9 @@ constexpr std::tuple<int, int> enforce_smem_limit(int block_m, int block_n, int
if (smem_usage > smem_limit) {
int const denom = 2 * element_size * (headdim + headdim_v);
int max_block_m = denom > 0 ? smem_limit / denom - adjusted_block_n : block_m;
if (max_block_m < 8) { max_block_m = 8; }
max_block_m = (max_block_m / 8) * 8;
block_m = max_block_m > 0 ? max_block_m : 8;
if (max_block_m < 16) { max_block_m = 16; }
max_block_m = (max_block_m / 16) * 16;
block_m = max_block_m > 0 ? max_block_m : 16;
adjusted_block_n = clamp_block_n_for_smem(block_m, adjusted_block_n, headdim, headdim_v, element_size, smem_limit);
}
return {block_m, adjusted_block_n};
Expand Down
5 changes: 3 additions & 2 deletions tests/hopper/test_tile_size_shared_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ def _load_bridge() -> ctypes.CDLL:


def estimate_smem_bytes(block_m: int, block_n: int, headdim: int, headdim_v: int, element_size: int) -> int:
# Mirror the double-buffer estimate used in hopper/tile_size.h.
return 2 * (block_m + block_n) * (headdim + headdim_v) * element_size
# Mirror the buffering-aware estimate used in hopper/tile_size.h.
buffering = 1 if headdim + headdim_v >= 512 else 2
return buffering * (block_m + block_n) * (headdim + headdim_v) * element_size


def test_tile_sizes_stay_within_blackwell_smem_budget():
Expand Down