From 042a8137eda7e592bc4e7b3bf25afbfd52e139e5 Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Thu, 22 Jan 2026 20:53:56 +0000
Subject: [PATCH 1/2] Clean up testing of outdated behavior

---
 tests/pytorch/attention/test_attention.py | 66 ++++++-----------------
 1 file changed, 17 insertions(+), 49 deletions(-)

diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index a5128653e..f08f0a15b 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -127,7 +127,7 @@ def test_gqa_mla_thd():
     if FusedAttnBackend["CK"] not in fused_attn_backends:
         pytest.skip("This test requires the CK fused attention backend.")
 
-    test_dot_product_attention(dtype, {"layout_1": config}, "layout_1", False, False, qkv_layout, False, True, False)
+    test_dot_product_attention(dtype, {"layout_1": config}, "layout_1", False, False, qkv_layout, False, True)
 
 @pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.")
 def test_dot_product_mem_calc():
@@ -179,9 +179,8 @@ def test_dot_product_mem_calc():
 @pytest.mark.parametrize("qkv_layout", [None])
 @pytest.mark.parametrize("swa", [False])
 @pytest.mark.parametrize("pad_between_seqs", [False])
-@pytest.mark.parametrize("share_cu_seqlens_ref", [False])
 def test_dot_product_attention(
-    dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs, share_cu_seqlens_ref
+    dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
 ):
     """Test DotProductAttention module"""
 
@@ -269,7 +268,6 @@ def test_dot_product_attention(
             workspace_opt,
             pad_between_seqs,
             is_training,
-            share_cu_seqlens_ref,
         )
     if len(fused_attn_backends) == 2:
         os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
@@ -284,7 +282,6 @@ def test_dot_product_attention(
             workspace_opt,
             pad_between_seqs,
             is_training,
-            share_cu_seqlens_ref, # Not used by AOT
         )
         os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
         os.environ["NVTE_FUSED_ATTN_CK"] = "1"
@@ -300,7 +297,6 @@ def test_dot_product_attention(
             workspace_opt,
             pad_between_seqs,
             is_training,
-            share_cu_seqlens_ref,
         )
     if IS_HIP_EXTENSION:
         os.environ["NVTE_CK_USES_FWD_V3"] = "0"
@@ -314,7 +310,6 @@ def test_dot_product_attention(
             workspace_opt,
             pad_between_seqs,
             is_training,
-            share_cu_seqlens_ref,
         )
 
 
@@ -329,7 +324,6 @@ def test_dot_product_attention(
             workspace_opt,
             pad_between_seqs,
             is_training,
-            share_cu_seqlens_ref,
         )
 
     logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
@@ -366,7 +360,7 @@
 @pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
 def test_dpa_checkpoint(dtype, model_configs, model):
     """Test DotProductAttention module with checkpointing"""
-    test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False, False)
+    test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
 
 
 model_configs_mla = {
@@ -395,7 +389,7 @@ def test_dpa_checkpoint(dtype, model_configs, model):
 @pytest.mark.parametrize("model", model_configs_mla.keys())
 def test_dpa_mla(dtype, model_configs, model):
     """Test DotProductAttention module with Multi-Latent Attention (MLA)"""
-    test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False, False)
+    test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False)
 
 
 model_configs_mask = {
@@ -450,7 +444,7 @@ def test_dpa_mla(dtype, model_configs, model):
 @pytest.mark.parametrize("model", model_configs_mask.keys())
 def test_dpa_mask(dtype, model_configs, model):
     """Test DotProductAttention module with different mask types"""
-    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False, False)
+    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
 
 
 model_configs_bias = {
@@ -560,7 +554,7 @@ def test_dpa_mask(dtype, model_configs, model):
 @pytest.mark.parametrize("model", model_configs_bias.keys())
 def test_dpa_bias(dtype, model_configs, model):
     """Test DotProductAttention module with different bias types"""
-    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False, False)
+    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
 
 
 model_configs_bias_shapes = {
@@ -598,7 +592,7 @@ def test_dpa_bias(dtype, model_configs, model):
 @pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
 def test_dpa_bias_shapes(dtype, model_configs, model):
     """Test DotProductAttention module with different bias types and shapes"""
-    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False, False)
+    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
 
 
 model_configs_swa = {
@@ -638,7 +632,7 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
 @pytest.mark.parametrize("model", model_configs_swa.keys())
 def test_dpa_sliding_window(dtype, model_configs, model):
     """Test DotProductAttention module with sliding window attention"""
-    test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False, False)
+    test_dot_product_attention(dtype, model_configs, model, False, True, None, True, False)
 
 
 model_configs_alibi_slopes = {
@@ -678,7 +672,7 @@ def test_dpa_sliding_window(dtype, model_configs, model):
 @pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
 def test_dpa_alibi_slopes(dtype, model_configs, model):
     """Test DotProductAttention module with ALiBi slopes"""
-    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False, False)
+    test_dot_product_attention(dtype, model_configs, model, False, True, None, False, False)
 
 
 qkv_layouts = [
@@ -739,7 +733,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model):
 @pytest.mark.parametrize("qkv_layout", qkv_layouts)
 def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
     """Test DotProductAttention module with different QKV layouts"""
-    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False, False)
+    test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
 
 
 qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
@@ -797,8 +791,6 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
     ),
 }
 
-padding_configs = ([(True, False), (False, False), (False, True)] if IS_HIP_EXTENSION
-                   else [(True, False), (False, False)])
 
 # ROCm fused attn does not use cudnn, return high numbers to avoid tests filtering out
 @pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
@@ -810,28 +802,16 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
 @pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
 @pytest.mark.parametrize("model", model_configs_layout_thd.keys())
 @pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
-@pytest.mark.parametrize(("pad_between_seqs", "share_cu_seqlens_ref"), padding_configs)
-def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout, pad_between_seqs, share_cu_seqlens_ref):
+@pytest.mark.parametrize("pad_between_seqs", [True, False])
+def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout, pad_between_seqs):
     """Test DotProductAttention module with different QKV layouts"""
     config = model_configs[model]
     if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
         pytest.skip("qkv_layout not applicable for MQA/GQA")
     if (pad_between_seqs==False and get_cudnn_version() < (9, 3, 0)):
         pytest.skip("cuDNN 9.3.0+ is required to run pad_between_seqs = False");
-
-    if share_cu_seqlens_ref: #ROCm specific config
-        _, _, fused_attn_backends = get_available_attention_backends(
-            config,
-            qkv_dtype=dtype,
-            qkv_layout=qkv_layout,
-            window_size=config.window_size,
-            pad_between_seqs=pad_between_seqs,
-        )
-        if FusedAttnBackend["CK"] not in fused_attn_backends:
-            pytest.skip("This test is only required for the CK fused attention backend.")
-
     test_dot_product_attention(
-        dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs, share_cu_seqlens_ref
+        dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
     )
 
 @pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.")
@@ -839,21 +819,10 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout, pad_between
 @pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
 @pytest.mark.parametrize("model", model_configs_layout_thd.keys())
 @pytest.mark.parametrize("qkv_layout", qkv_layouts_thd)
-@pytest.mark.parametrize(("pad_between_seqs", "share_cu_seqlens_ref"), padding_configs)
-def test_dpa_qkv_layout_thd_mqa_gqa(dtype, model_configs, model, qkv_layout, pad_between_seqs, share_cu_seqlens_ref):
+@pytest.mark.parametrize("pad_between_seqs", [True, False])
+def test_dpa_qkv_layout_thd_mqa_gqa(dtype, model_configs, model, qkv_layout, pad_between_seqs):
     config = model_configs[model]
 
-    if share_cu_seqlens_ref:
-        _, _, fused_attn_backends = get_available_attention_backends(
-            config,
-            qkv_dtype=dtype,
-            qkv_layout=qkv_layout,
-            window_size=config.window_size,
-            pad_between_seqs=pad_between_seqs,
-        )
-        if FusedAttnBackend["CK"] not in fused_attn_backends:
-            pytest.skip("This test is only required for the CK fused attention backend.")
-
     def find_factors(x):
         f = []
         for i in range(2, x + 1):
@@ -866,7 +835,7 @@ def find_factors(x):
     for num_q_per_gqa_group in num_querys_per_gqa_group:
         config.num_gqa_groups = config.num_heads // num_q_per_gqa_group
         test_dot_product_attention(
-            dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs, share_cu_seqlens_ref
+            dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
         )
 
 
@@ -879,7 +848,6 @@ def _run_dot_product_attention(
     workspace_opt: bool,
     pad_between_seqs: bool,
     is_training: bool,
-    share_cu_seqlens_ref: bool = False,
 ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
     """Run DotProductAttention module with one forward pass and one backward pass"""
 
@@ -1163,7 +1131,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
         k = inp[1]
         v = inp[2]
         d_out = out_grad
-        if pad_between_seqs or not share_cu_seqlens_ref:
+        if pad_between_seqs:
             cu_seqlens_q_padded = cu_seqlens_q_after_pad
             cu_seqlens_kv_padded = cu_seqlens_kv_after_pad
         out = block(

From 7b645af6e181d47baf360f3b91ca0e308fbb2e3c Mon Sep 17 00:00:00 2001
From: Meekail Zain
Date: Wed, 28 Jan 2026 16:17:07 -0600
Subject: [PATCH 2/2] Minimize cumulative diff from upstream

---
 tests/pytorch/attention/test_attention.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index f08f0a15b..337bc1646 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -1117,9 +1117,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
     if not is_training:
         block = block.eval()
 
-    cu_seqlens_q_padded = None
-    cu_seqlens_kv_padded = None
-
     # Run a forward and backward pass
     if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
         q = inp_orig[0]
@@ -1131,9 +1128,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
         k = inp[1]
         v = inp[2]
         d_out = out_grad
-        if pad_between_seqs:
-            cu_seqlens_q_padded = cu_seqlens_q_after_pad
-            cu_seqlens_kv_padded = cu_seqlens_kv_after_pad
         out = block(
             q,
             k,
@@ -1145,8 +1139,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
             max_seqlen_kv=config.max_seqlen_kv,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_kv=cu_seqlens_kv,
-            cu_seqlens_q_padded=cu_seqlens_q_padded,
-            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
+            cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None,
+            cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None,
             attn_mask_type=config.attn_mask_type,
             checkpoint_core_attention=ckpt_attn,
             core_attention_bias_type=config.attn_bias_type,