74 changes: 18 additions & 56 deletions tests/pytorch/attention/test_attention.py
@@ -127,7 +127,7 @@ def test_gqa_mla_thd():
if FusedAttnBackend["CK"] not in fused_attn_backends:
pytest.skip("This test requires the CK fused attention backend.")

test_dot_product_attention(dtype, {"layout_1": config}, "layout_1", False, False, qkv_layout, False, True, False)
test_dot_product_attention(dtype, {"layout_1": config}, "layout_1", False, False, qkv_layout, False, True)

@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.")
def test_dot_product_mem_calc():
@@ -179,9 +179,8 @@ def test_dot_product_mem_calc():
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
@pytest.mark.parametrize("share_cu_seqlens_ref", [False])
def test_dot_product_attention(
dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs, share_cu_seqlens_ref
dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
):
"""Test DotProductAttention module"""

@@ -269,7 +268,6 @@ def test_dot_product_attention(
workspace_opt,
pad_between_seqs,
is_training,
share_cu_seqlens_ref,
)
if len(fused_attn_backends) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
@@ -284,7 +282,6 @@ def test_dot_product_attention(
workspace_opt,
pad_between_seqs,
is_training,
share_cu_seqlens_ref, # Not used by AOT
)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
os.environ["NVTE_FUSED_ATTN_CK"] = "1"
@@ -300,7 +297,6 @@ def test_dot_product_attention(
workspace_opt,
pad_between_seqs,
is_training,
share_cu_seqlens_ref,
)
if IS_HIP_EXTENSION:
os.environ["NVTE_CK_USES_FWD_V3"] = "0"
@@ -314,7 +310,6 @@ def test_dot_product_attention(
workspace_opt,
pad_between_seqs,
is_training,
share_cu_seqlens_ref,
)


@@ -329,7 +324,6 @@ def test_dot_product_attention(
workspace_opt,
pad_between_seqs,
is_training,
share_cu_seqlens_ref,
)

logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
@@ -366,7 +360,7 @@ def test_dot_product_attention(
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
def test_dpa_checkpoint(dtype, model_configs, model):
"""Test DotProductAttention module with checkpointing"""
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False, False)
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_mla = {
@@ -395,7 +389,7 @@ def test_dpa_checkpoint(dtype, model_configs, model):
@pytest.mark.parametrize("model", model_configs_mla.keys())
def test_dpa_mla(dtype, model_configs, model):
"""Test DotProductAttention module with Multi-Latent Attention (MLA)"""
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False, False)
test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)


model_configs_mask = {
Expand Down Expand Up @@ -450,7 +444,7 @@ def test_dpa_mla(dtype, model_configs, model):
@pytest.mark.parametrize("model", model_configs_mask.keys())
def test_dpa_mask(dtype, model_configs, model):
"""Test DotProductAttention module with different mask types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False, False)
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_bias = {
Expand Down Expand Up @@ -560,7 +554,7 @@ def test_dpa_mask(dtype, model_configs, model):
@pytest.mark.parametrize("model", model_configs_bias.keys())
def test_dpa_bias(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False, False)
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_bias_shapes = {
Expand Down Expand Up @@ -598,7 +592,7 @@ def test_dpa_bias(dtype, model_configs, model):
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
def test_dpa_bias_shapes(dtype, model_configs, model):
"""Test DotProductAttention module with different bias types and shapes"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False, False)
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


model_configs_swa = {
Expand Down Expand Up @@ -638,7 +632,7 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
@pytest.mark.parametrize("model", model_configs_swa.keys())
def test_dpa_sliding_window(dtype, model_configs, model):
"""Test DotProductAttention module with sliding window attention"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False, False)
test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)


model_configs_alibi_slopes = {
Expand Down Expand Up @@ -678,7 +672,7 @@ def test_dpa_sliding_window(dtype, model_configs, model):
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
def test_dpa_alibi_slopes(dtype, model_configs, model):
"""Test DotProductAttention module with ALiBi slopes"""
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False, False)
test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)


qkv_layouts = [
Expand Down Expand Up @@ -739,7 +733,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model):
@pytest.mark.parametrize("qkv_layout", qkv_layouts)
def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False, False)
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)


qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
Expand Down Expand Up @@ -797,8 +791,6 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
),
}

padding_configs = ([(True, False), (False, False), (False, True)] if IS_HIP_EXTENSION
else [(True, False), (False, False)])

# ROCm fused attention does not use cuDNN; get_cudnn_version() reports a high version there so these version checks do not filter out the tests
@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
@@ -810,50 +802,27 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
@pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
@pytest.mark.parametrize("model", model_configs_layout_thd.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
@pytest.mark.parametrize(("pad_between_seqs", "share_cu_seqlens_ref"), padding_configs)
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout, pad_between_seqs, share_cu_seqlens_ref):
@pytest.mark.parametrize("pad_between_seqs", [True, False])
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout, pad_between_seqs):
"""Test DotProductAttention module with different QKV layouts"""
config = model_configs[model]
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
pytest.skip("qkv_layout not applicable for MQA/GQA")
if (pad_between_seqs==False and get_cudnn_version() < (9, 3, 0)):
pytest.skip("cuDNN 9.3.0+ is required to run pad_between_seqs = False");

if share_cu_seqlens_ref: #ROCm specific config
_, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
)
if FusedAttnBackend["CK"] not in fused_attn_backends:
pytest.skip("This test is only required for the CK fused attention backend.")

test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs, share_cu_seqlens_ref
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)

@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
@pytest.mark.parametrize("model", model_configs_layout_thd.keys())
@pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
@pytest.mark.parametrize(("pad_between_seqs", "share_cu_seqlens_ref"), padding_configs)
def test_dpa_qkv_layout_thd_mqa_gqa(dtype, model_configs, model, qkv_layout, pad_between_seqs, share_cu_seqlens_ref):
@pytest.mark.parametrize("pad_between_seqs", [True, False])
def test_dpa_qkv_layout_thd_mqa_gqa(dtype, model_configs, model, qkv_layout, pad_between_seqs):
config = model_configs[model]

if share_cu_seqlens_ref:
_, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
)
if FusedAttnBackend["CK"] not in fused_attn_backends:
pytest.skip("This test is only required for the CK fused attention backend.")

def find_factors(x):
f = []
for i in range(2, x + 1):
@@ -866,7 +835,7 @@ def find_factors(x):
for num_q_per_gqa_group in num_querys_per_gqa_group:
config.num_gqa_groups = config.num_heads // num_q_per_gqa_group
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs, share_cu_seqlens_ref
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)


Expand All @@ -879,7 +848,6 @@ def _run_dot_product_attention(
workspace_opt: bool,
pad_between_seqs: bool,
is_training: bool,
share_cu_seqlens_ref: bool = False,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run DotProductAttention module with one forward pass and one backward pass"""

@@ -1149,9 +1117,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
if not is_training:
block = block.eval()

cu_seqlens_q_padded = None
cu_seqlens_kv_padded = None

# Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
q = inp_orig[0]
@@ -1163,9 +1128,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
k = inp[1]
v = inp[2]
d_out = out_grad
if pad_between_seqs or not share_cu_seqlens_ref:
cu_seqlens_q_padded = cu_seqlens_q_after_pad
cu_seqlens_kv_padded = cu_seqlens_kv_after_pad
out = block(
q,
k,
@@ -1177,8 +1139,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=ckpt_attn,
core_attention_bias_type=config.attn_bias_type,
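The net effect of the last hunk is that the padded cumulative sequence lengths are now chosen inline at the call site and reach the attention block only on the fused-attention path; the pre-computed cu_seqlens_q_padded / cu_seqlens_kv_padded locals and the share_cu_seqlens_ref switch are gone. A minimal, self-contained sketch of that selection logic follows (the helper name and the example tensor are illustrative only, not part of the diff):

import torch

# Sketch only, assuming the tensor names used in the test above: padded
# cumulative sequence lengths are forwarded only for the fused-attention
# backend; the flash and unfused backends receive None.
def select_padded_cu_seqlens(backend, cu_seqlens_q_after_pad, cu_seqlens_kv_after_pad):
    if backend == "FusedAttention":
        return cu_seqlens_q_after_pad, cu_seqlens_kv_after_pad
    return None, None

# Example: three sequences of lengths 3, 5, 2, each padded to a multiple of 4,
# giving padded offsets [0, 4, 12, 16].
cu_seqlens_after_pad = torch.tensor([0, 4, 12, 16], dtype=torch.int32)
q_pad, kv_pad = select_padded_cu_seqlens("FusedAttention", cu_seqlens_after_pad, cu_seqlens_after_pad)
assert torch.equal(q_pad, cu_seqlens_after_pad)
q_pad, kv_pad = select_padded_cu_seqlens("FlashAttention", cu_seqlens_after_pad, cu_seqlens_after_pad)
assert q_pad is None and kv_pad is None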