diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py
index 484019d661..eee9c20444 100644
--- a/flashinfer/autotuner.py
+++ b/flashinfer/autotuner.py
@@ -717,8 +717,10 @@ def _get_cache_key(
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
     ) -> Tuple:
-        if hasattr(input_shapes, '__len__'):
-            shapes_tuple = tuple(tuple(s) if hasattr(s, '__iter__') else s for s in input_shapes)
+        if hasattr(input_shapes, "__len__"):
+            shapes_tuple = tuple(
+                tuple(s) if hasattr(s, "__iter__") else s for s in input_shapes
+            )
         else:
             shapes_tuple = input_shapes
         return (
diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py
index 8d657ac82f..ceafe8d053 100644
--- a/flashinfer/fused_moe/core.py
+++ b/flashinfer/fused_moe/core.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import paddle
 import functools
 from enum import IntEnum
diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py
index 4a49bfc4fe..66a7a3606a 100644
--- a/flashinfer/gemm.py
+++ b/flashinfer/gemm.py
@@ -562,9 +562,7 @@ def forward(
         # tgv_gemm takes mat1 as weights and mat2 as input tensor
         # from [m,k]x[k,n]+[n,] to [n,k]x[k,m]+[n,]
         gemm_fn = module.tgv_gemm
-        c = torch.empty(
-            (a.shape[0], b.shape[1]), dtype=a.dtype, device=a.place
-        )
+        c = torch.empty((a.shape[0], b.shape[1]), dtype=a.dtype, device=a.place)
         gemm_fn(b.t(), a.t(), bias, tactic, c, pdl)
         return c
 
@@ -2078,12 +2076,12 @@ def bmm_fp8(
     if out is None:
         out = torch.empty(
             (A.shape[0], A.shape[1], B.shape[2]),
-            device=a.place,
+            device=A.place,
             dtype=dtype,
         )
 
     workspace_buffer = _get_cache_buf(
-        "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, a.place
+        "bmm_fp8_workspace", DEFAULT_WORKSPACE_SIZE, A.place
     )
 
     if backend == "cudnn":
diff --git a/tests/attention/test_attention_sink_blackwell.py b/tests/attention/test_attention_sink_blackwell.py
index 09d666653c..19c379832c 100644
--- a/tests/attention/test_attention_sink_blackwell.py
+++ b/tests/attention/test_attention_sink_blackwell.py
@@ -13,7 +13,9 @@
 See the License for the specific language governing permissions and
 limitations under the License.
""" + import paddle + paddle.compat.enable_torch_proxy() import einops import pytest @@ -22,7 +24,7 @@ from tests.test_helpers.sink_attention_reference import sink_attention_unified import flashinfer -from flashinfer.utils import get_compute_capability +# from flashinfer.utils import get_compute_capability # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) diff --git a/tests/conftest.py b/tests/conftest.py index c1e050d0ea..58c44ab6b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Set import paddle + paddle.compat.enable_torch_proxy() import pytest import torch @@ -12,7 +13,7 @@ # from torch.torch_version import __version__ as torch_version import flashinfer -from flashinfer.jit import MissingJITCacheError +# from flashinfer.jit import MissingJITCacheError # Global tracking for JIT cache coverage # Store tuples of (test_name, module_name, spec_info) @@ -128,8 +129,8 @@ def wrapper(*args, **kwargs): def pytest_configure(config): if os.environ.get("FLASHINFER_TEST_TORCH_COMPILE", "0") == "1": - if torch_version < TorchVersion("2.4"): - pytest.skip("torch.compile requires torch >= 2.4") + # if torch_version < TorchVersion("2.4"): + # pytest.skip("torch.compile requires torch >= 2.4") _set_torch_compile_options() for fn in TORCH_COMPILE_FNS: _monkeypatch_add_torch_compile(fn) @@ -144,8 +145,8 @@ def pytest_runtest_call(item): # skip OOM error and missing JIT cache errors try: item.runtest() - except: - assert(False) + except Exception: + raise # try: # item.runtest() # except (torch.cuda.OutOfMemoryError, RuntimeError) as e: diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 190cddc92d..1543895b07 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
""" + import paddle + paddle.compat.enable_torch_proxy() import functools from typing import Tuple @@ -48,7 +50,9 @@ get_w2_permute_indices_with_cache, _maybe_get_cached_w3_w1_permute_indices, ) -from flashinfer.utils import calculate_tile_tokens_dim, get_compute_capability + +# from flashinfer.utils import calculate_tile_tokens_dim, get_compute_capability +from flashinfer.utils import calculate_tile_tokens_dim @functools.cache @@ -2504,10 +2508,7 @@ def test_moe_quantization_classes( else: # Other routing methods (Renormalize, RenormalizeNaive, Llama4) use bfloat16 expert_logits = torch.randn((num_tokens, num_experts), device="cuda") - print("oringingin expert_logits:", expert_logits) - expert_logits = expert_logits.to( - torch.bfloat16 - ) + expert_logits = expert_logits.to(torch.bfloat16) # torch.set_printoptions(edgeitems=1000) # 显示更多边缘项 # torch.set_printoptions(linewidth=1000) # 增加每行宽度 print("expert_logits:", expert_logits) diff --git a/tests/test_helpers/sink_attention_reference.py b/tests/test_helpers/sink_attention_reference.py index 5527b45373..5995a44ed6 100644 --- a/tests/test_helpers/sink_attention_reference.py +++ b/tests/test_helpers/sink_attention_reference.py @@ -311,8 +311,9 @@ def sink_attention_unified( # mask = torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze( # 1 # ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0) - mask = torch.arange(kv_len - qo_len, kv_len).unsqueeze( - 1) >= torch.arange(0, kv_len).unsqueeze(0) + mask = torch.arange(kv_len - qo_len, kv_len).unsqueeze(1) >= torch.arange( + 0, kv_len + ).unsqueeze(0) if window_left >= 0: row_idx = torch.arange(qo_len, dtype=torch.int32, device=q.device)[ :, None