39 commits
5250085
fp4 dense
zoooo0820 Oct 24, 2025
b0c863a
[WIP] support nvfp4, dense part
zoooo0820 Oct 27, 2025
d5f3fd2
[wip] developing loading qwen model
zoooo0820 Oct 28, 2025
1176cae
loading
bukejiyu Nov 5, 2025
7137054
update
bukejiyu Nov 6, 2025
0594090
dense fp4 OK, cudagraph error
zoooo0820 Nov 6, 2025
ae80853
[WIP] moe forward part
zoooo0820 Nov 7, 2025
6b2ebd6
with flashinfer-backend
zoooo0820 Nov 14, 2025
0b28b4b
qwen3_moe_fp4
bukejiyu Nov 17, 2025
2d2bd06
update
bukejiyu Nov 18, 2025
c329d92
support flashinfer-cutlass moe, qwen3-moe-fp4 OK
zoooo0820 Nov 18, 2025
eb089b3
support ernie4.5-fp4
zoooo0820 Nov 19, 2025
1931732
solve confilict
zoooo0820 Nov 19, 2025
03aa695
fix load error
zoooo0820 Nov 20, 2025
5233398
add some ut
zoooo0820 Nov 20, 2025
748e812
add docs
zoooo0820 Nov 20, 2025
3d38d73
Merge branch 'develop' into support_fp4_moe
zoooo0820 Dec 3, 2025
be11fc3
fix CLA, test
Echo-Nie Jan 12, 2026
e071d51
Merge remote-tracking branch 'zoooo/support_fp4_moe' into fp4_moe
Echo-Nie Jan 12, 2026
509fc33
fix the apply() in ModelOptNvFp4FusedMoE
Echo-Nie Jan 12, 2026
798cb6b
fix CodeStyle
Echo-Nie Jan 12, 2026
17d0740
Merge branch 'develop' into fp4_moe
Echo-Nie Jan 13, 2026
ca2a699
del the PADDLE_COMPATIBLE_API
Echo-Nie Jan 13, 2026
359b6b6
Merge branch 'develop' into fp4_moe
Echo-Nie Jan 13, 2026
14fc296
fix broken url: nvidia_gpu.md
Echo-Nie Jan 13, 2026
a25fea0
fix docs
Echo-Nie Jan 13, 2026
d93cdb5
Merge branch 'develop' into fp4_moe
Echo-Nie Jan 14, 2026
88c8347
Merge branch 'develop' into fp4_moe
Echo-Nie Jan 14, 2026
14bbd6b
Merge branch 'develop' into fp4_moe
Echo-Nie Jan 15, 2026
d7426fd
Merge branch 'develop' into fp4_moe
Echo-Nie Jan 15, 2026
b3e600d
fix token_ids
Echo-Nie Jan 19, 2026
d9f8a74
Merge branch 'develop' into fp4_moe
Echo-Nie Jan 19, 2026
ee8f622
fix CI in Hopper
Echo-Nie Jan 19, 2026
4057e1e
move flashinfer imports inside the function
Echo-Nie Jan 20, 2026
0faf0c2
Merge branch 'PaddlePaddle:develop' into fp4_moe
Echo-Nie Jan 20, 2026
f9ec344
fix model_runner
Echo-Nie Jan 20, 2026
ce6d40f
Merge branch 'develop' into fp4_moe
Echo-Nie Jan 21, 2026
fb71cca
Remove skip condition for CUDA version in nvfp4 test
Echo-Nie Jan 22, 2026
a2fa9ff
Merge branch 'develop' into fp4_moe
Echo-Nie Jan 22, 2026
66 changes: 66 additions & 0 deletions docs/quantization/nvfp4.md
@@ -0,0 +1,66 @@

# NVFP4 Quantization
NVFP4 is an innovative 4-bit floating-point format introduced by NVIDIA. For detailed information, please refer to [Introducing NVFP4 for Efficient and Accurate Low-Precision Inference](https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/).

Built on [FlashInfer](https://github.com/flashinfer-ai/flashinfer), FastDeploy supports inference of NVFP4-quantized models in the format produced by [ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer).

- Note: this feature currently supports only FP4-quantized models from the Ernie and Qwen series.

## How to Use
### Environment Setup
#### Supported Environment
- **Supported Hardware**: NVIDIA GPUs with SM (compute capability) >= 100
- **PaddlePaddle Version**: 3.3.0 or higher
- **FastDeploy Version**: 2.5.0 or higher

#### FastDeploy Installation
Please ensure that FastDeploy is installed with NVIDIA GPU support.
Follow the official guide to set up the base environment: [FastDeploy NVIDIA GPU Environment Installation Guide](https://paddlepaddle.github.io/FastDeploy/get_started/installation/nvidia_gpu/).
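
Before launching the service, you can optionally sanity-check the environment. The snippet below is a minimal sketch and not part of FastDeploy: it assumes PaddlePaddle exposes `paddle.device.cuda.get_device_capability()`, and it checks for FlashInfer and `nvcc` the same way the `has_flashinfer()` helper added in this PR does.

```python
import importlib.util
import shutil

import paddle

# NVFP4 kernels require SM >= 100 (compute capability 10.0 or newer).
major, minor = paddle.device.cuda.get_device_capability()
assert major * 10 + minor >= 100, f"NVFP4 requires SM >= 100, found SM {major}{minor}"

# FlashInfer must be importable, and nvcc must be on PATH for JIT compilation,
# mirroring the has_flashinfer() check in fastdeploy/flashinfer.py.
assert importlib.util.find_spec("flashinfer") is not None, "flashinfer is not installed"
assert shutil.which("nvcc") is not None, "nvcc not found on PATH"
print("Environment looks compatible with NVFP4 inference.")
```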

### Running the Inference Service
```bash
python -m fastdeploy.entrypoints.openai.api_server \
--model nv-community/Qwen3-30B-A3B-FP4 \
--port 8180 \
--metrics-port 8181 \
--engine-worker-queue-port 8182 \
--cache-queue-port 8183 \
--tensor-parallel-size 1 \
--max-model-len 32768 \
--max-num-seqs 128
```
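
This PR also adds two environment variables in `fastdeploy/envs.py` for choosing the FlashInfer backends: `FD_NVFP4_GEMM_BACKEND` (dense GEMM: `flashinfer-cutlass`, `flashinfer-trtllm`, or `flashinfer-cudnn`) and `FD_FLASHINFER_MOE_BACKEND` (MoE: `flashinfer-cutlass` or `flashinfer-trtllm`); both default to `None`. The sketch below shows one way to launch the same server with explicit backends; the choice of `flashinfer-cutlass` is only an example taken from the option list, not a recommendation from this PR.

```python
import os
import subprocess
import sys

# Select FlashInfer backends via the environment variables added in fastdeploy/envs.py.
env = dict(os.environ)
env["FD_NVFP4_GEMM_BACKEND"] = "flashinfer-cutlass"      # dense NVFP4 GEMM backend
env["FD_FLASHINFER_MOE_BACKEND"] = "flashinfer-cutlass"  # NVFP4 MoE backend

# Launch the OpenAI-compatible server with the same arguments as the command above.
subprocess.run(
    [
        sys.executable, "-m", "fastdeploy.entrypoints.openai.api_server",
        "--model", "nv-community/Qwen3-30B-A3B-FP4",
        "--port", "8180",
        "--metrics-port", "8181",
        "--engine-worker-queue-port", "8182",
        "--cache-queue-port", "8183",
        "--tensor-parallel-size", "1",
        "--max-model-len", "32768",
        "--max-num-seqs", "128",
    ],
    env=env,
    check=True,
)
```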

### API Access
Send requests to the service with the following command:

```shell
curl -X POST "http://0.0.0.0:8180/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "把李白的静夜思改写为现代诗"}
]
}'
```

The FastDeploy service API is compatible with the OpenAI protocol; you can also send requests with the following Python code.

```python
import openai
host = "0.0.0.0"
port = "8180"
client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null")

response = client.chat.completions.create(
    model="null",
    messages=[
        {"role": "system", "content": "I'm a helpful AI assistant."},
        {"role": "user", "content": "把李白的静夜思改写为现代诗"},
    ],
    stream=True,
)
for chunk in response:
    if chunk.choices[0].delta:
        print(chunk.choices[0].delta.content, end='')
print('\n')
```
66 changes: 66 additions & 0 deletions docs/zh/quantization/nvfp4.md
@@ -0,0 +1,66 @@
[English](../../quantization/nvfp4.md)

# NVFP4 Quantization
NVFP4 is an innovative 4-bit floating-point format introduced by NVIDIA. For detailed information, please refer to [Introducing NVFP4 for Efficient and Accurate Low-Precision Inference](https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/).

Built on [FlashInfer](https://github.com/flashinfer-ai/flashinfer), FastDeploy supports inference of NVFP4-quantized models in the format produced by [ModelOpt](https://github.com/NVIDIA/TensorRT-Model-Optimizer).

- Note: this feature currently supports only FP4-quantized models from the Ernie and Qwen series.

## How to Use
### Environment Setup
#### Supported Environment
- **Supported Hardware**: NVIDIA GPUs with SM (compute capability) >= 100
- **PaddlePaddle Version**: 3.3.0 or higher
- **FastDeploy Version**: 2.5.0 or higher

#### FastDeploy Installation
FastDeploy must be installed with NVIDIA GPU support; follow the official guide to set up the base environment: [FastDeploy NVIDIA GPU Environment Installation Guide](https://paddlepaddle.github.io/FastDeploy/zh/get_started/installation/nvidia_gpu/).

### Running the Inference Service
```bash
python -m fastdeploy.entrypoints.openai.api_server \
--model nv-community/Qwen3-30B-A3B-FP4 \
--port 8180 \
--metrics-port 8181 \
--engine-worker-queue-port 8182 \
--cache-queue-port 8183 \
--tensor-parallel-size 1 \
--max-model-len 32768 \
--max-num-seqs 128
```

### API Access
Send requests to the service with the following command:

```shell
curl -X POST "http://0.0.0.0:8180/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user", "content": "把李白的静夜思改写为现代诗"}
]
}'
```

The FastDeploy service API is compatible with the OpenAI protocol; you can also send requests with the following Python code.

```python
import openai
host = "0.0.0.0"
port = "8180"
client = openai.Client(base_url=f"http://{host}:{port}/v1", api_key="null")

response = client.chat.completions.create(
    model="null",
    messages=[
        {"role": "system", "content": "I'm a helpful AI assistant."},
        {"role": "user", "content": "把李白的静夜思改写为现代诗"},
    ],
    stream=True,
)
for chunk in response:
    if chunk.choices[0].delta:
        print(chunk.choices[0].delta.content, end='')
print('\n')
```
4 changes: 4 additions & 0 deletions fastdeploy/envs.py
@@ -136,6 +136,10 @@
"FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR": lambda: os.getenv("FD_LIMIT_THINKING_CONTENT_TRUNCATE_STR", "</think>"),
# Timeout for cache_transfer_manager process exit
"FD_CACHE_PROC_EXIT_TIMEOUT": lambda: int(os.getenv("FD_CACHE_PROC_EXIT_TIMEOUT", "600")),
# FP4 dense GEMM backend; can be flashinfer-cutlass, flashinfer-trtllm, flashinfer-cudnn, or None (default: None)
"FD_NVFP4_GEMM_BACKEND": lambda: os.getenv("FD_NVFP4_GEMM_BACKEND", None),
# FlashInfer MoE backend; can be flashinfer-cutlass, flashinfer-trtllm, or None (default: None)
"FD_FLASHINFER_MOE_BACKEND": lambda: os.getenv("FD_FLASHINFER_MOE_BACKEND", None),
# Count for cache_transfer_manager process error
"FD_CACHE_PROC_ERROR_COUNT": lambda: int(os.getenv("FD_CACHE_PROC_ERROR_COUNT", "10")),
# API_KEY required for service authentication
33 changes: 33 additions & 0 deletions fastdeploy/flashinfer.py
@@ -0,0 +1,33 @@
"""
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import functools
import importlib
import importlib.util
import shutil


@functools.cache
def has_flashinfer() -> bool:
"""Return `True` if FlashInfer is available."""
# Use find_spec to check if the module exists without importing it
# This avoids potential CUDA initialization side effects
if importlib.util.find_spec("flashinfer") is None:
return False
# Also check if nvcc is available since it's required to JIT compile flashinfer
if shutil.which("nvcc") is None:
return False
return True
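
A hypothetical caller could use this helper to guard optional FlashInfer code paths. The sketch below is illustrative and not part of the PR; only `has_flashinfer()` itself comes from `fastdeploy/flashinfer.py`.

```python
from fastdeploy.flashinfer import has_flashinfer


def pick_nvfp4_backend() -> str:
    """Illustrative only: prefer FlashInfer kernels when the package and nvcc are available."""
    if has_flashinfer():
        return "flashinfer"
    # Fallback name is hypothetical; the real dispatch lives in the quantization layers.
    return "unsupported"
```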
41 changes: 37 additions & 4 deletions fastdeploy/model_executor/layers/moe/moe.py
@@ -325,10 +325,10 @@ def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_
expert_param = param[expert_id - self.expert_id_offset]
dim = -1 if shard_dim else 0
param_shard_size = expert_param.shape[dim] // 2
if shard_id == "gate":
switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
if (shard_id == "gate" and not switch_w13) or (shard_id == "up" and switch_w13):
param_shard_offset = 0
else:
# shard_id == "up":
param_shard_offset = param_shard_size
expert_param = slice_fn(
expert_param, shard_dim, start=param_shard_offset, end=param_shard_offset + param_shard_size
@@ -342,8 +342,12 @@
)

# To ensure compatibility across backends, apply an extra transpose for GCU and XPU

if expert_param.shape != loaded_weight.shape:
loaded_weight = loaded_weight.transpose([1, 0])
if len(expert_param.shape) != len(loaded_weight.shape):
loaded_weight = loaded_weight.reshape(expert_param.shape)
else:
loaded_weight = loaded_weight.transpose([1, 0])
assert expert_param.shape == loaded_weight.shape, (
f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
)
@@ -402,6 +406,32 @@ def _load_fused_experts_weight(self, param, loaded_weight):
for i in range(self.num_local_experts):
param.tensor_track.mark(start=0, batch_id=i)

def _load_per_tensor_weight_scale(
self,
param,
expert_id,
loaded_weight,
shard_id,
):
loaded_weight = get_tensor(loaded_weight)
expert_param = param[expert_id - self.expert_id_offset]
if shard_id in ["gate", "up"]:
idx = 0 if shard_id == "gate" else 1
if expert_param[idx].shape != loaded_weight.shape:
if len(expert_param[idx].shape) != len(loaded_weight.shape):
loaded_weight = loaded_weight.reshape(expert_param[idx].shape)
else:
loaded_weight = loaded_weight.transpose([1, 0])

expert_param[idx].set_value(loaded_weight)
elif shard_id == "down":
if expert_param.shape != loaded_weight.shape:
if len(expert_param.shape) != len(loaded_weight.shape):
loaded_weight = loaded_weight.reshape(expert_param.shape)
else:
loaded_weight = loaded_weight.transpose([1, 0])
expert_param.set_value(loaded_weight)

def _load_expert_weight(
self,
param,
@@ -410,7 +440,10 @@ def _load_expert_weight(
shard_id,
shard_dim=None,
):
if shard_id == "down":
weight_type = getattr(param, "weight_type", None)
if weight_type in ["weight_scale_2", "input_scale"]:
self._load_per_tensor_weight_scale(param, expert_id, loaded_weight, shard_id)
elif shard_id == "down":
self._load_down_weight(param, expert_id, loaded_weight, shard_id, shard_dim)
elif shard_id in ["gate", "up"]:
self._load_gate_up_weight(param, expert_id, loaded_weight, shard_id, shard_dim)
8 changes: 8 additions & 0 deletions fastdeploy/model_executor/layers/quantization/__init__.py
@@ -34,6 +34,7 @@
"mix_quant",
"tensor_wise_fp8",
"kvcache",
"modelopt_fp4",
"mxfp4",
]

@@ -133,6 +134,11 @@ def _get_offline_quant_config_name(quantization_config, is_torch_weight, is_v1_l
has_block_size = "weight_block_size" in quantization_config
if quant_method == "fp8" and has_block_size:
quant_config_name = "block_wise_fp8"
elif quant_method == "modelopt":
if quantization_config.get("quant_algo", "") == "NVFP4":
quant_config_name = "modelopt_fp4"
else:
raise ValueError("modelopt only supports NVFP4 quantization.")
elif quant_method == "mxfp4":
quant_config_name = "mxfp4"
else:
@@ -152,6 +158,7 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
from .block_wise_fp8 import BlockWiseFP8Config
from .kv_cache import KvCacheQuantConfig
from .mix_quant import MixQuantConfig
from .nvfp4 import ModelOptNvFp4Config
from .tensor_wise_fp8 import TensorWiseFP8Config
from .w4a8 import W4A8Config
from .w4afp8 import W4AFP8Config
@@ -176,6 +183,7 @@ def get_quantization_config(quantization: str) -> Type[QuantConfigBase]:
"tensor_wise_fp8": TensorWiseFP8Config,
"kvcache": KvCacheQuantConfig,
"mix_quant": MixQuantConfig,
"modelopt_fp4": ModelOptNvFp4Config,
}
if envs.FD_MOE_MXFP4_BACKEND is not None:
method_to_config["mxfp4"] = MXFP4Config
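
For reference, a ModelOpt NVFP4 checkpoint is expected to carry a `quantization_config` with `quant_method: "modelopt"` and `quant_algo: "NVFP4"`, which the dispatch above maps to the `modelopt_fp4` config name and then to `ModelOptNvFp4Config`. The standalone sketch below replays that mapping with a hypothetical config dict; only the two keys and the error message are taken from the code in this PR.

```python
# Standalone sketch of the dispatch added in _get_offline_quant_config_name();
# the config dict itself is hypothetical.
quantization_config = {"quant_method": "modelopt", "quant_algo": "NVFP4"}

quant_method = quantization_config.get("quant_method", "")
if quant_method == "modelopt":
    if quantization_config.get("quant_algo", "") == "NVFP4":
        quant_config_name = "modelopt_fp4"  # later resolved to ModelOptNvFp4Config
    else:
        raise ValueError("modelopt only supports NVFP4 quantization.")

print(quant_config_name)  # -> modelopt_fp4
```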