diff --git a/tests/attention/test_attention_sink_blackwell.py b/tests/attention/test_attention_sink_blackwell.py
index 104ca9cf8e..09d666653c 100644
--- a/tests/attention/test_attention_sink_blackwell.py
+++ b/tests/attention/test_attention_sink_blackwell.py
@@ -250,6 +250,7 @@ def test_blackwell_trtllm_gen_context_attention_sink(
         atol, rtol = 1e-2, 1e-2
     else:
         raise ValueError(f"Unsupported dtype: {dtype}")
     ref_o = o_ref.float().numpy()
-    np.testing.assert_allclose(ref_o, paddle_o, atol=atol, rtol=rtol)
+    output_o = output.float().numpy()
+    np.testing.assert_allclose(ref_o, output_o, atol=atol, rtol=rtol)
 
diff --git a/tests/conftest.py b/tests/conftest.py
index 1db8d85973..c1e050d0ea 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -145,7 +145,7 @@ def pytest_runtest_call(item):
     try:
         item.runtest()
     except:
-        # assert(False)
+        raise
         # try:
         #     item.runtest()
         # except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
diff --git a/tests/test_helpers/sink_attention_reference.py b/tests/test_helpers/sink_attention_reference.py
index 1f8d1a13c4..5527b45373 100644
--- a/tests/test_helpers/sink_attention_reference.py
+++ b/tests/test_helpers/sink_attention_reference.py
@@ -312,8 +312,7 @@ def sink_attention_unified(
         #     1
         # ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0)
         mask = torch.arange(kv_len - qo_len, kv_len).unsqueeze(
-            1
-        ) >= torch.arange(0, kv_len).unsqueeze(0)
+            1) >= torch.arange(0, kv_len).unsqueeze(0)
         if window_left >= 0:
             row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.device)[
                 :, None
@@ -343,10 +342,8 @@ def sink_attention_unified(
             mask &= abs_row_positions - window_left <= col_idx
     else:
         # Non-causal mask
-        q_device = q.place
-        print("+++",q_device)
         if mode == "incremental":
-            mask = torch.ones(1, kv_len, device=q_device, dtype=torch.bool)
+            mask = torch.ones(1, kv_len, device=q.device, dtype=torch.bool)
         if window_left >= 0:
             col_idx = torch.arange(kv_len, dtype=torch.int32, device=q.device)
             mask = (kv_len - 1 - window_left) <= col_idx