diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 000000000..dc0cc7cbc --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,3 @@ +.github/CODEOWNERS @fzyzcjy @Ying1123 +.github/workflows/ @yushengsu-thu +/miles/ @fzyzcjy @yueming-yuan diff --git a/build_conda.sh b/build_conda.sh index ac1bb2091..46fc12df6 100644 --- a/build_conda.sh +++ b/build_conda.sh @@ -21,13 +21,13 @@ micromamba install -n miles cuda cuda-nvtx cuda-nvtx-dev nccl -c nvidia/label/cu micromamba install -n miles -c conda-forge cudnn -y # prevent installing cuda 13.0 for sglang -pip install cuda-python==12.9.1 -pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu129 +pip install cuda-python==13.1.0 +pip install torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu129 # install sglang git clone https://github.com/sgl-project/sglang.git cd sglang -git checkout 303cc957e62384044dfa8e52d7d8af8abe12f0ac +git checkout 5e2cda6158e670e64b926a9985d65826c537ac82 # Install the python packages pip install -e "python[all]" @@ -39,7 +39,7 @@ pip install cmake ninja MAX_JOBS=64 pip -v install flash-attn==2.7.4.post1 --no-build-isolation pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps -pip install --no-build-isolation "transformer_engine[pytorch]==2.8.0" +pip install --no-build-isolation "transformer_engine[pytorch]==2.10.0" pip install flash-linear-attention==0.4.0 NVCC_APPEND_FLAGS="--threads 4" \ pip -v install --disable-pip-version-check --no-cache-dir \ @@ -50,7 +50,7 @@ git clone https://github.com/NVIDIA/Megatron-LM.git --recursive && \ cd Megatron-LM && git checkout ${MEGATRON_COMMIT} && \ pip install -e . -pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@9b8b788fdeb9c2ee528183214cef65a99b71e7d5 --no-cache-dir --force-reinstall +pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@dc6876905830430b5054325fa4211ff302169c6b --no-cache-dir --force-reinstall pip install git+https://github.com/fzyzcjy/Megatron-Bridge.git@dev_rl --no-build-isolation pip install nvidia-modelopt[torch]>=0.37.0 --no-build-isolation @@ -60,6 +60,9 @@ git clone https://github.com/NVIDIA/Megatron-LM.git --recursive && \ cd Megatron-LM/ && git checkout core_v0.14.0 && \ pip install -e . +# https://github.com/pytorch/pytorch/issues/168167 +pip install nvidia-cudnn-cu12==9.16.0.29 + # install miles and apply patches # if miles does not exist locally, clone it @@ -76,6 +79,6 @@ fi # apply patch cd $BASE_DIR/sglang -git apply $MILES_DIR/docker/patch/v0.5.5.post1/sglang.patch +git apply $MILES_DIR/docker/patch/v0.5.6/sglang.patch cd $BASE_DIR/Megatron-LM -git apply $MILES_DIR/docker/patch/v0.5.5.post1/megatron.patch +git apply $MILES_DIR/docker/patch/v0.5.6/megatron.patch \ No newline at end of file diff --git a/build_uv_berlin.sh b/build_uv_berlin.sh new file mode 100644 index 000000000..c4f6cd7cd --- /dev/null +++ b/build_uv_berlin.sh @@ -0,0 +1,194 @@ +#!/bin/bash + +# ============================================================================= +# Miles Build Script (CUDA 12.8 Slurm Version) +# +# This script uses uv for Python environment management. +# It relies on the Slurm module system for CUDA 12.8 instead of pip packages. +# +# Configuration: +# - CUDA: 12.8 (via module load) +# - PyTorch: 2.8.0 (cu128) +# - Flash Attention 3: Prebuilt for cu128 + torch2.8 +# - Flash Attention 2: Prebuilt for cu128 + torch2.8 +# ============================================================================= + +set -e # Exit on error + +BASE_DIR="$(pwd)/.." # change this if you want a different base directory. Default is parent directory of miles + +if [ -z "$BASE_DIR" ]; then + echo "BASE_DIR is not set. Please set it to proceed with the installation." + exit 1 +fi + +# ============================================================================= +# Load Slurm Module +# ============================================================================= +echo "Loading CUDA 12.8 module..." +module load CUDA/12.8 + +# Verify NVCC is in path +if ! command -v nvcc &> /dev/null; then + echo "CRITICAL ERROR: nvcc not found after loading module." + echo "Please ensure 'module load CUDA/12.8' works on this cluster." + exit 1 +fi + +# ============================================================================= +# Install uv if not already installed +# ============================================================================= +if ! command -v uv &> /dev/null; then + echo "Installing uv..." + curl -LsSf https://astral.sh/uv/install.sh | sh + export PATH="$HOME/.local/bin:$PATH" +fi + +# ============================================================================= +# Create Python virtual environment with uv +# ============================================================================= +# Create virtual environment with Python 3.12 +uv venv --python 3.12 + +# Activate the virtual environment +source ".venv/bin/activate" + +cd "$BASE_DIR" +# ============================================================================= +# Install PyTorch with CUDA 12.8 +# ============================================================================= +echo "Installing PyTorch 2.8.0 with CUDA 12.8..." + +# Install cuda-python (Pinned to 12.8 to match the module) +uv pip install cuda-python==12.8.0 + +# Install PyTorch 2.8.0 for CUDA 12.8 +uv pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu128 + +# Set TORCH_CUDA_ARCH_LIST for our GPU architectures +# 8.0 = A100 (Ampere), 9.0 = H100 (Hopper) +export TORCH_CUDA_ARCH_LIST="8.0;9.0" + +# ============================================================================= +# Install sglang +# ============================================================================= +echo "Installing sglang..." +cd "$BASE_DIR" +if [ ! -d "$BASE_DIR/sglang" ]; then + git clone https://github.com/sgl-project/sglang.git +fi +cd sglang +git checkout 303cc957e62384044dfa8e52d7d8af8abe12f0ac +uv pip install -e "python[all]" + +# ============================================================================= +# Install build tools +# ============================================================================= +uv pip install cmake ninja packaging build wheel + +# ============================================================================= +# Install Flash Attention 3 (prebuilt wheels for cu128 + torch2.8) +# ============================================================================= +echo "Installing Flash Attention 3..." +# Using windreamer's wheel index for CUDA 12.8 and PyTorch 2.8.0 +uv pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch280 --extra-index-url https://download.pytorch.org/whl/cu128 + +# ============================================================================= +# Install Flash Attention 2 (prebuilt wheel for Megatron compatibility) +# ============================================================================= +echo "Installing Flash Attention 2..." +uv pip install https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.3.18/flash_attn-2.7.4%2Bcu128torch2.8-cp312-cp312-linux_x86_64.whl + +# ============================================================================= +# Install mbridge, transformer_engine, flash-linear-attention +# ============================================================================= +echo "Installing mbridge, transformer_engine, flash-linear-attention..." +uv pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps + +# Transformer Engine 2.8.0 (compatible with CUDA 12.8) +uv pip install --no-build-isolation "transformer_engine[pytorch]==2.8.0" --no-cache-dir + +uv pip install flash-linear-attention==0.4.0 + +# ============================================================================= +# Install NVIDIA Apex (requires CUDA compilation) +# ============================================================================= +echo "Installing NVIDIA Apex (Compiling from source)..." +# We utilize the loaded CUDA module for compilation +NVCC_APPEND_FLAGS="--threads 4" \ + APEX_CPP_EXT=1 \ + APEX_CUDA_EXT=1 \ + APEX_PARALLEL_BUILD=8 \ + uv pip install -v --no-cache-dir \ + --no-build-isolation \ + git+https://github.com/NVIDIA/apex.git@10417aceddd7d5d05d7cbf7b0fc2daad1105f8b4 + +# ============================================================================= +# Install Megatron-LM +# ============================================================================= +echo "Installing Megatron-LM..." +cd "$BASE_DIR" +if [ ! -d "$BASE_DIR/Megatron-LM" ]; then + git clone https://github.com/NVIDIA/Megatron-LM.git --recursive +fi +cd Megatron-LM +git checkout core_v0.14.0 +uv pip install -e . + +# ============================================================================= +# Install additional dependencies +# ============================================================================= +echo "Installing additional dependencies..." +uv pip install poetry pybind11 +uv pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@9b8b788fdeb9c2ee528183214cef65a99b71e7d5 --no-cache-dir --force-reinstall +uv pip install git+https://github.com/fzyzcjy/Megatron-Bridge.git@dev_rl --no-build-isolation +uv pip install "nvidia-modelopt[torch]>=0.37.0" --no-build-isolation + +# ============================================================================= +# Install remaining packages +# ============================================================================= +uv pip install sglang_router ring_flash_attn pylatexenc +uv pip install -U "ray[data,train,tune,serve]" + +# ============================================================================= +# Install miles +# ============================================================================= +echo "Installing miles..." +if [ ! -d "$BASE_DIR/miles" ]; then + cd "$BASE_DIR" + git clone https://github.com/radixark/miles.git + cd miles/ + export MILES_DIR="$BASE_DIR/miles" + uv pip install -e . +elif [ -f "$BASE_DIR/pyproject.toml" ]; then + export MILES_DIR="$BASE_DIR" + cd "$MILES_DIR" + uv pip install -e . +else + export MILES_DIR="$BASE_DIR/miles" + cd "$MILES_DIR" + uv pip install -e . +fi + +# ============================================================================= +# Apply patches +# ============================================================================= +echo "Applying patches..." +cd "$BASE_DIR/sglang" +git apply "$MILES_DIR/docker/patch/v0.5.5.post1/sglang.patch" || echo "sglang patch already applied or failed" + +cd "$BASE_DIR/Megatron-LM" +git apply "$MILES_DIR/docker/patch/v0.5.5.post1/megatron.patch" || echo "Megatron patch already applied or failed" + +echo "" +echo "=============================================================================" +echo "Installation complete!" +echo "" +echo "To activate the environment, run:" +echo " module load CUDA/12.8" +echo " source $BASE_DIR/miles-venv/bin/activate" +echo "" +echo "Environment configured using System CUDA from 'module load CUDA/12.8'" +echo "PyTorch Version: 2.8.0 (cu128)" +echo "CUDA_HOME: $CUDA_HOME" +echo "=============================================================================" \ No newline at end of file diff --git a/build_uv_juelich.sh b/build_uv_juelich.sh new file mode 100644 index 000000000..9ba6830d2 --- /dev/null +++ b/build_uv_juelich.sh @@ -0,0 +1,198 @@ +#!/bin/bash + +# ============================================================================= +# Miles Build Script (CUDA 12.6 Slurm Version) +# +# This script uses uv for Python environment management. +# It relies on the Slurm module system for CUDA 12.6 instead of pip packages. +# +# Configuration: +# - CUDA: 12.6 (via module load) +# - PyTorch: 2.8.0 (cu126) +# - Flash Attention 3: Prebuilt for cu126 + torch2.8 +# - Flash Attention 2: Prebuilt for cu128 + torch2.8 +# ============================================================================= + +set -e # Exit on error + +BASE_DIR="$(pwd)/.." # change this if you want a different base directory. Default is parent directory of miles + +if [ -z "$BASE_DIR" ]; then + echo "BASE_DIR is not set. Please set it to proceed with the installation." + exit 1 +fi + +# ============================================================================= +# Load Slurm Module +# ============================================================================= +echo "Loading CUDA 12.6 module..." +module load CUDA/12 +module load cuDNN/9.5.0.50-CUDA-12 +module load NCCL/default-CUDA-12 +module load Clang/18.1.8 + +# Verify NVCC is in path +if ! command -v nvcc &> /dev/null; then + echo "CRITICAL ERROR: nvcc not found after loading module." + echo "Please ensure 'module load CUDA/12' works on this cluster." + exit 1 +fi + +# ============================================================================= +# Install uv if not already installed +# ============================================================================= +if ! command -v uv &> /dev/null; then + echo "Installing uv..." + curl -LsSf https://astral.sh/uv/install.sh | sh + export PATH="$HOME/.local/bin:$PATH" +fi + +# ============================================================================= +# Create Python virtual environment with uv +# ============================================================================= +# Create virtual environment with Python 3.12 +uv venv --python 3.12 + +# Activate the virtual environment +source ".venv/bin/activate" + +cd "$BASE_DIR" +# ============================================================================= +# Install PyTorch with CUDA 12.6 +# ============================================================================= +echo "Installing PyTorch 2.8.0 with CUDA 12.6..." + +# Install cuda-python (Pinned to 12.6 to match the module) +uv pip install cuda-python==12.6.0 + +# Install PyTorch 2.8.0 for CUDA 12.6 +uv pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu126 + +# Set TORCH_CUDA_ARCH_LIST for our GPU architectures +# 8.0 = A100 (Ampere), 9.0 = H100 (Hopper) +export TORCH_CUDA_ARCH_LIST="8.0;9.0" + +# ============================================================================= +# Install sglang +# ============================================================================= +echo "Installing sglang..." +cd "$BASE_DIR" +if [ ! -d "$BASE_DIR/sglang" ]; then + git clone https://github.com/sgl-project/sglang.git +fi +cd sglang +git checkout 303cc957e62384044dfa8e52d7d8af8abe12f0ac +uv pip install -e "python[all]" + +# ============================================================================= +# Install build tools +# ============================================================================= +uv pip install cmake ninja packaging build wheel + +# ============================================================================= +# Install Flash Attention 3 (prebuilt wheels for cu126 + torch2.8) +# ============================================================================= +echo "Installing Flash Attention 3..." +# Using windreamer's wheel index for CUDA 12.6 and PyTorch 2.8.0 +uv pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu126_torch280 --extra-index-url https://download.pytorch.org/whl/cu126 +# ============================================================================= +# Install Flash Attention 2 (prebuilt wheel for Megatron compatibility) +# ============================================================================= +echo "Installing Flash Attention 2..." +uv pip install https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.11/flash_attn-2.8.3+cu126torch2.8-cp312-cp312-linux_x86_64.whl + +# ============================================================================= +# Install mbridge, transformer_engine, flash-linear-attention +# ============================================================================= +echo "Installing mbridge, transformer_engine, flash-linear-attention..." +uv pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps + +# Transformer Engine 2.8.0 (compatible with CUDA 12.8) +uv pip install --no-build-isolation "transformer_engine[pytorch]==2.8.0" --no-cache-dir + +uv pip install flash-linear-attention==0.4.0 + +# ============================================================================= +# Install NVIDIA Apex (requires CUDA compilation) +# ============================================================================= +echo "Installing NVIDIA Apex (Compiling from source)..." +# We utilize the loaded CUDA module for compilation +NVCC_APPEND_FLAGS="--threads 4" \ + APEX_CPP_EXT=1 \ + APEX_CUDA_EXT=1 \ + APEX_PARALLEL_BUILD=8 \ + uv pip install -v --no-cache-dir \ + --no-build-isolation \ + git+https://github.com/NVIDIA/apex.git@10417aceddd7d5d05d7cbf7b0fc2daad1105f8b4 + +# ============================================================================= +# Install Megatron-LM +# ============================================================================= +echo "Installing Megatron-LM..." +cd "$BASE_DIR" +if [ ! -d "$BASE_DIR/Megatron-LM" ]; then + git clone https://github.com/NVIDIA/Megatron-LM.git --recursive +fi +cd Megatron-LM +git checkout core_v0.14.0 +uv pip install -e . + +# ============================================================================= +# Install additional dependencies +# ============================================================================= +echo "Installing additional dependencies..." +uv pip install poetry pybind11 +uv pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@9b8b788fdeb9c2ee528183214cef65a99b71e7d5 --no-cache-dir --force-reinstall +uv pip install git+https://github.com/fzyzcjy/Megatron-Bridge.git@dev_rl --no-build-isolation +uv pip install "nvidia-modelopt[torch]>=0.37.0" --no-build-isolation + +# ============================================================================= +# Install remaining packages +# ============================================================================= +uv pip install sglang_router ring_flash_attn pylatexenc +uv pip install -U "ray[data,train,tune,serve]" + +# ============================================================================= +# Install miles +# ============================================================================= +echo "Installing miles..." +if [ ! -d "$BASE_DIR/miles" ]; then + cd "$BASE_DIR" + git clone https://github.com/radixark/miles.git + cd miles/ + export MILES_DIR="$BASE_DIR/miles" + uv pip install -e . +elif [ -f "$BASE_DIR/pyproject.toml" ]; then + export MILES_DIR="$BASE_DIR" + cd "$MILES_DIR" + uv pip install -e . +else + export MILES_DIR="$BASE_DIR/miles" + cd "$MILES_DIR" + uv pip install -e . +fi + +# ============================================================================= +# Apply patches +# ============================================================================= +echo "Applying patches..." +cd "$BASE_DIR/sglang" +git apply "$MILES_DIR/docker/patch/v0.5.5.post1/sglang.patch" || echo "sglang patch already applied or failed" + +cd "$BASE_DIR/Megatron-LM" +git apply "$MILES_DIR/docker/patch/v0.5.5.post1/megatron.patch" || echo "Megatron patch already applied or failed" + +echo "" +echo "=============================================================================" +echo "Installation complete!" +echo "" +echo "To activate the environment, run:" +echo " module load CUDA/12" +echo " module load cuDNN/9.5.0.50-CUDA-12" +echo " module load NCCL/default-CUDA-12" +echo " source $BASE_DIR/venv/bin/activate" +echo "" +echo "Environment configured using System CUDA from 'module load CUDA/12'" +echo "PyTorch Version: 2.8.0 (cu126)" +echo "CUDA_HOME: $CUDA_HOME" +echo "=============================================================================" \ No newline at end of file diff --git a/docker/Dockerfile.rocm_MI350-5 b/docker/Dockerfile.rocm_MI350-5 new file mode 100644 index 000000000..6dc1353f0 --- /dev/null +++ b/docker/Dockerfile.rocm_MI350-5 @@ -0,0 +1,252 @@ +#### Use the base image for ROCm 7 / gfx950 (MI355) + +# The Docker image built with this Dockerfile: +# Base image: ROCm 7 with vllm pre-built for gfx950 +# Target GPU: MI355 (gfx950) + + +FROM rocm/sgl-dev:rocm7-vllm-20250904 + +SHELL ["/bin/bash", "-ceuxo", "pipefail"] + +ARG MAX_JOBS=128 +ENV MAX_JOBS=${MAX_JOBS} + +# Set environment variables for gfx950 +ENV GPU_ARCH=gfx950 +ENV PYTORCH_ROCM_ARCH=gfx950 +ENV GPU_ARCH_LIST=gfx950 +ENV AMDGPU_TARGET=gfx950 + + +########################################### +##############1. Install AITER############# +########################################### +WORKDIR /app + +RUN pip uninstall -y aiter || true +RUN rm -rf aiter +RUN git clone https://github.com/ROCm/aiter.git \ + && cd aiter \ + && git checkout v0.1.7.post2 \ + && git submodule update --init --recursive \ + && GPU_ARCHS=gfx950 python setup.py develop +########################################### +########################################### +########################################### + + +########################################### +####2. Install TransformerEngine for gfx950 +########################################### +WORKDIR /app + +RUN rm -rf TransformerEngine +RUN git clone https://github.com/ROCm/TransformerEngine.git \ + && cd TransformerEngine \ + && git checkout 90c04bcdc3c109505b318f40a39680263af55edf \ + && git submodule update --init --recursive + +ENV NVTE_FRAMEWORK=pytorch +ENV NVTE_ROCM_ARCH=gfx950 +ENV NVTE_USE_HIPBLASLT=1 +ENV NVTE_USE_ROCM=1 +ENV CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr" + +RUN cd TransformerEngine && pip install . -v +########################################### +########################################### +########################################### + + +######################################### +####3. Install Megatron-LM (NVIDIA version) +######################################### +WORKDIR /app + +RUN pip install "numpy>=1.21.0,<2.0" --force-reinstall + +RUN pip uninstall -y megatron-core || true +RUN rm -rf Megatron-LM +RUN git clone https://github.com/NVIDIA/Megatron-LM \ + && cd Megatron-LM \ + && git checkout 48406695c4efcf1026a7ed70bb390793918dd97b \ + && pip install -e . +######################################### +######################################### +######################################### + + +######################################## +############ 4. Install mbridge######### +######################################## +RUN pip install git+https://github.com/ISEEKYAN/mbridge.git --no-deps +######################################## +######################################## +######################################## + + +######################################## +######5. Install Ray#################### +######################################## +RUN pip uninstall ray -y || true +RUN pip install "ray[data,train,tune,serve]==2.47.1" +######################################## +######################################## +######################################## + + +######################################### +###6. Install torch_memory_saver######### +######################################### +RUN pip install torch_memory_saver +######################################### +######################################### + + +####################################### +####7. Install Apex for ROCm########### +####################################### +WORKDIR /app + +RUN pip uninstall -y apex || true +RUN rm -rf apex +RUN git clone https://github.com/ROCm/apex.git \ + && cd apex \ + && python setup.py install +####################################### +####################################### +####################################### + + +######################################## +###8. Install slime agent framework deps +######################################## +RUN pip install pydra_config==0.0.15 +RUN pip install together +RUN pip install google-generativeai +RUN pip install tensorboard +######################################## +######################################## +######################################## + + +######################################## +###9. Set performance environment vars## +######################################## +ENV HIP_FORCE_DEV_KERNARG=1 +ENV HSA_NO_SCRATCH_RECLAIM=1 +ENV SGLANG_USE_AITER=1 +ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 +ENV SGLANG_MOE_PADDING=1 +ENV SGLANG_SET_CPU_AFFINITY=1 +ENV SGLANG_ROCM_FUSED_DECODE_MLA=1 +ENV SGLANG_USE_ROCM700A=1 +ENV NCCL_MIN_NCHANNELS=112 +ENV VLLM_FP8_PADDING=1 +ENV VLLM_FP8_ACT_PADDING=1 +ENV VLLM_FP8_WEIGHT_PADDING=1 +ENV VLLM_FP8_REDUCE_CONV=1 +ENV TORCHINDUCTOR_MAX_AUTOTUNE=1 +ENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 +######################################## +######################################## +######################################## + + +########################################### +##############Install SGLang############### +########################################### +WORKDIR /app + +# Install prerequisites +RUN pip install IPython orjson python-multipart torchao==0.9.0 pybind11 + +# Clone SGLang +RUN pip uninstall -y sgl_kernel sglang || true +RUN rm -rf sglang +RUN git clone https://github.com/sgl-project/sglang.git \ + && cd sglang \ + && git checkout v0.5.6 + +# Build sgl-kernel for gfx950 +RUN cd sglang/sgl-kernel \ + && rm -f pyproject.toml \ + && mv pyproject_rocm.toml pyproject.toml \ + && AMDGPU_TARGET=gfx950 python setup_rocm.py install + +# Install SGLang +RUN cd sglang \ + && rm -rf python/pyproject.toml \ + && mv python/pyproject_other.toml python/pyproject.toml \ + && pip install -e "python[all_hip]" + +# Test SGLang installation +RUN python -c "import sglang; import sgl_kernel; print('SGLang + sgl_kernel: OK')" + +RUN python -m pip cache purge +########################################### +########################################### +########################################### + + +########################################### +#### APPLY PATCHES (gfx950/MI355) ######### +########################################### + +# Copy patches from slime repo +COPY amd_patch/latest /app/patch + +# Apply Megatron patches +RUN cd /app/Megatron-LM \ + && git apply /app/patch/amd_megatron_fused_kernels_init.patch \ + && git apply /app/patch/megatron.patch --3way \ + && if grep -R -n '^<<<<<<< ' .; then \ + echo "Patch failed to apply cleanly. Please resolve conflicts." && \ + exit 1; \ + fi \ + && pip install -e . -v + +# Apply SGLang patch +RUN cd /app/sglang \ + && git apply /app/patch/sglang.patch || echo "Check patch compatibility with v0.5.6" \ + && if grep -R -n '^<<<<<<< ' .; then \ + echo "Patch failed to apply cleanly. Please resolve conflicts." && \ + exit 1; \ + fi + +# Copy MOE configs for gfx950/MI355 +RUN find /app/sglang/python/sglang/srt/layers/quantization/configs/ \ + /app/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \ + -type f -name '*MI300X*' 2>/dev/null | while read f; do \ + cp "$f" "$(echo $f | sed 's/MI300X/MI300X_VF/')" 2>/dev/null || true; \ + cp "$f" "$(echo $f | sed 's/MI300X/MI355/')" 2>/dev/null || true; \ +done + +########################################### +########################################### +########################################### + + +######################################## +#### Install additional packages######## +######################################## +RUN pip install sglang-router --force-reinstall +######################################## +######################################## +######################################## + + +######################################## +# Fix click/ray incompatibility with Python 3.10 +######################################## +RUN pip install click==8.2.1 +######################################## +######################################## +######################################## + + +WORKDIR /app + +CMD ["/usr/bin/bash"] + diff --git a/docker/README.md b/docker/README.md index 92f559e72..156169c72 100644 --- a/docker/README.md +++ b/docker/README.md @@ -5,10 +5,10 @@ We will publish 2 kinds of docker images: 2. latest version, which aligns to `lmsysorg/sglang:latest`. current stable version is: -- sglang v0.5.5.post1 (303cc957e62384044dfa8e52d7d8af8abe12f0ac), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133) +- sglang nightly-dev-20251208-5e2cda61 (5e2cda6158e670e64b926a9985d65826c537ac82), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133) history versions: -- sglang v0.5.0rc0-cu126 (8ecf6b9d2480c3f600826c7d8fef6a16ed603c3f), megatron 48406695c4efcf1026a7ed70bb390793918dd97b +- sglang v0.5.5.post1 (303cc957e62384044dfa8e52d7d8af8abe12f0ac), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133) The command to build: diff --git a/docker/patch/latest/megatron.patch b/docker/patch/latest/megatron.patch index 0526ac55a..3a56ff4c2 100644 --- a/docker/patch/latest/megatron.patch +++ b/docker/patch/latest/megatron.patch @@ -219,14 +219,14 @@ index 6aec66e6d..6ca48b55f 100644 mtp_loss = loss_mask * mtp_loss if self.training: diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py -index a36b67364..8739270f2 100644 +index a36b67364..ed8883e32 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -657,6 +657,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): # TE FusedAdam will not accumulate step for empty param groups, so we need to # align the step across param groups. param_group["step"] = int(step) -+ if param_group["step"] is None: ++ if "step" in param_group and param_group["step"] is None: + del param_group["step"] # Grad scaler state. diff --git a/docker/patch/latest/sglang.patch b/docker/patch/latest/sglang.patch index 055e09096..de12cdd43 100644 --- a/docker/patch/latest/sglang.patch +++ b/docker/patch/latest/sglang.patch @@ -215,7 +215,7 @@ index 932f52aeb..79c6b664f 100644 hidden_states = self._communicate_simple_fn( diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py -index 3293a8a59..02999afd0 100644 +index 3293a8a59..a075b71ce 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -84,15 +84,12 @@ class RMSNorm(CustomOp): @@ -236,7 +236,7 @@ index 3293a8a59..02999afd0 100644 self.variance_epsilon = eps self.hidden_size = hidden_size self.variance_size_override = ( -@@ -105,15 +102,16 @@ class RMSNorm(CustomOp): +@@ -105,21 +102,26 @@ class RMSNorm(CustomOp): self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, @@ -255,7 +255,17 @@ index 3293a8a59..02999afd0 100644 return rms_norm_batch_invariant( x, self.weight.data, -@@ -179,17 +177,35 @@ class RMSNorm(CustomOp): + self.variance_epsilon, + ) + if residual is not None: ++ # TODO: Ideally we want to have (a+b)+c. but right now we can only have a+(b+c). ++ # (a+b)+c != a+(b+c), we probably need to add another parameter to fused_add_rmsnorm ++ if post_residual_addition is not None: ++ residual = residual + post_residual_addition + fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) + return x, residual + out = rmsnorm(x, self.weight.data, self.variance_epsilon) +@@ -179,17 +181,35 @@ class RMSNorm(CustomOp): self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, diff --git a/docker/patch/v0.5.6/megatron.patch b/docker/patch/v0.5.6/megatron.patch new file mode 100644 index 000000000..3a56ff4c2 --- /dev/null +++ b/docker/patch/v0.5.6/megatron.patch @@ -0,0 +1,869 @@ +diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py +index 41c21d93d..ef80f72d6 100644 +--- a/megatron/core/dist_checkpointing/strategies/common.py ++++ b/megatron/core/dist_checkpointing/strategies/common.py +@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): + msc = MultiStorageClientFeature.import_package() + return msc.torch.load(load_path, map_location='cpu') + else: +- return torch.load(load_path, map_location='cpu') ++ return torch.load(load_path, map_location='cpu', weights_only=False) + except FileNotFoundError as e: + err_msg = f'Common file {load_path} does not exist' + if MultiStorageClientFeature.is_enabled(): +diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py +index ccf5242a2..9b6d3e31f 100644 +--- a/megatron/core/dist_checkpointing/strategies/torch.py ++++ b/megatron/core/dist_checkpointing/strategies/torch.py +@@ -427,6 +427,15 @@ def _restore_dict_types(x: Union[dict, list, Any], keys_template: Union[dict, li + _restore_dict_types(x_val, templ_val) + + ++@dataclass ++class MCoreMetadata(Metadata): ++ """Metadata with mcore specific data.""" ++ ++ # holds data related to flattened_range ++ # TODO: remove when flattened_range is properly removed ++ mcore_data: Optional[Dict[str, Dict[str, Any]]] = None # Mcore related data about each tensor ++ ++ + @dataclass(frozen=True) + class MCoreSavePlan(SavePlan): + """SavePlan with MCore specific data.""" +@@ -499,9 +508,10 @@ class MCoreSavePlanner(DefaultSavePlanner): + def create_global_plan(self, all_plans: List[MCoreSavePlan]) -> Tuple[List[SavePlan], Metadata]: + """Merges MCore data for all plans.""" + global_plan, metadata = super().create_global_plan(all_plans) +- metadata.mcore_data = dict( ++ mcore_data = dict( + ChainMap(*(plan.mcore_data for plan in all_plans)) # type: ignore[arg-type] + ) ++ metadata = MCoreMetadata(mcore_data=mcore_data, **vars(metadata)) + return global_plan, metadata + + def create_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan: +@@ -556,10 +566,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + def _validate_global_shapes(self, metadata, sharded_tensors): + for sh_ten in sharded_tensors: + if sh_ten.key not in metadata.state_dict_metadata: +- raise KeyError( +- f"{sh_ten.key} from model not in state dict:" +- f" {sorted(metadata.state_dict_metadata.keys())}" +- ) ++ # raise KeyError( ++ # f"{sh_ten.key} from model not in state dict:" ++ # f" {sorted(metadata.state_dict_metadata.keys())}" ++ # ) ++ print(f"{sh_ten.key} from model not in state dict, will skip") ++ continue + loaded_shape = metadata.state_dict_metadata[sh_ten.key].size + expected_shape = self._expected_shape(sh_ten) + if loaded_shape != expected_shape: +@@ -589,7 +601,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + tensor_metadata = self.metadata.state_dict_metadata + metadata_with_sizes = [ + (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) +- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() ++ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata + ] + try: + # Temporarily set sizes to expected shapes +@@ -918,6 +930,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): + planner=MCoreLoadPlanner( + shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, + allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, ++ allow_partial_load=True, + ), + ) + +diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py +index fe26e8b43..4451f2776 100644 +--- a/megatron/core/distributed/__init__.py ++++ b/megatron/core/distributed/__init__.py +@@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads + from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel + from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel + from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig ++ ++# Backward compatibility patch for FSDP module reorganization ++import sys ++import importlib.util ++ ++spec = importlib.util.find_spec('megatron.core.distributed.fsdp.src.megatron_fsdp') ++if spec: ++ custom_fsdp = importlib.util.module_from_spec(spec) ++ spec.loader.exec_module(custom_fsdp) ++ sys.modules['megatron.core.distributed.custom_fsdp'] = custom_fsdp ++ if hasattr(custom_fsdp, 'MegatronFSDP'): ++ custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP +diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py +index 7727efe1e..966fe652a 100644 +--- a/megatron/core/extensions/transformer_engine.py ++++ b/megatron/core/extensions/transformer_engine.py +@@ -366,6 +366,7 @@ class TELinear(te.pytorch.Linear): + ) + + for param in self.parameters(): ++ setattr(param, "parallel_mode", parallel_mode) + if is_expert: + # Reduce the gradient on the expert_data_parallel group for expert linear layers + setattr(param, "allreduce", not self.expert_parallel) +diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py +index 860ee64a9..80944b702 100755 +--- a/megatron/core/models/gpt/gpt_layer_specs.py ++++ b/megatron/core/models/gpt/gpt_layer_specs.py +@@ -79,6 +79,8 @@ def get_gpt_layer_with_transformer_engine_spec( + qk_l2_norm: Optional[bool] = False, + use_te_op_fuser: Optional[bool] = False, + use_kitchen: bool = False, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). + +@@ -178,9 +180,11 @@ def get_gpt_layer_with_transformer_engine_spec( + ), + ), + self_attn_bda=get_bias_dropout_add, ++ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, + pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, ++ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, + sharded_state_dict_keys_map={ + "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight", + "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias", +diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py +index 6aec66e6d..6ca48b55f 100644 +--- a/megatron/core/models/gpt/gpt_model.py ++++ b/megatron/core/models/gpt/gpt_model.py +@@ -355,6 +355,7 @@ class GPTModel(LanguageModule): + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, ++ mtp_kwargs: Optional[dict] = {}, + ) -> Tensor: + """Forward function of the GPT Model This function passes the input tensors + through the embedding layer, and then the decoeder and finally into the post +@@ -410,6 +411,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, ++ mtp_kwargs=mtp_kwargs, + ) + + def _postprocess( +@@ -431,6 +433,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, ++ mtp_kwargs={}, + ): + """Postprocesses decoder hidden states to generate logits or compute loss. + +@@ -446,7 +449,7 @@ class GPTModel(LanguageModule): + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + +- if mtp_in_postprocess: ++ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, +@@ -465,25 +468,37 @@ class GPTModel(LanguageModule): + if not self.post_process: + return hidden_states + +- if self.mtp_process: +- mtp_labels = labels.clone() ++ if self.mtp_process and mtp_kwargs.get('mtp_labels', None) is not None: ++ mtp_labels = mtp_kwargs['mtp_labels'].clone() ++ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) ++ + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) ++ else: ++ # Otherwise, roll the loss_mask to keep up with the mtp_labels ++ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) + for mtp_layer_number in range(self.config.mtp_num_layers): + # output +- mtp_logits, _ = self.output_layer( +- hidden_states_list[mtp_layer_number + 1], +- weight=output_weight, +- runtime_gather_output=runtime_gather_output, ++ output_layer_params = {k: v.detach() for k, v in self.output_layer.named_parameters()} ++ output_layer_buffers = dict(self.output_layer.named_buffers()) ++ mtp_logits, _ = torch.func.functional_call( ++ self.output_layer, ++ {**output_layer_params, **output_layer_buffers}, ++ (hidden_states_list[mtp_layer_number + 1],), ++ { ++ "weight": output_weight.detach() if output_weight else None, ++ "runtime_gather_output": runtime_gather_output, ++ }, + ) + # Calc loss for the current Multi-Token Prediction (MTP) layers. +- mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group) +- loss_mask, num_tokens = roll_tensor( +- loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group ++ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) ++ new_loss_mask, num_tokens = roll_tensor( ++ loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params + ) ++ loss_mask = new_loss_mask * loss_mask + mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits) + mtp_loss = loss_mask * mtp_loss + if self.training: +diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py +index a36b67364..ed8883e32 100644 +--- a/megatron/core/optimizer/distrib_optimizer.py ++++ b/megatron/core/optimizer/distrib_optimizer.py +@@ -657,6 +657,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + # TE FusedAdam will not accumulate step for empty param groups, so we need to + # align the step across param groups. + param_group["step"] = int(step) ++ if "step" in param_group and param_group["step"] is None: ++ del param_group["step"] + + # Grad scaler state. + if self.grad_scaler: +diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py +index a40c85a88..86688c331 100644 +--- a/megatron/core/parallel_state.py ++++ b/megatron/core/parallel_state.py +@@ -9,6 +9,7 @@ from typing import Callable, List, Optional + + import numpy as np + import torch ++import torch.distributed as dist + + from .utils import GlobalMemoryBuffer, is_torch_min_version + +diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py +index 63ee9d1f5..b90b744c1 100644 +--- a/megatron/core/pipeline_parallel/p2p_communication.py ++++ b/megatron/core/pipeline_parallel/p2p_communication.py +@@ -26,22 +26,22 @@ def _batched_p2p_ops( + ops = [] + if tensor_send_prev is not None: + send_prev_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group ++ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, + ) + ops.append(send_prev_op) + if tensor_recv_prev is not None: + recv_prev_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, + ) + ops.append(recv_prev_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group ++ torch.distributed.isend, tensor_send_next, next_pipeline_rank, + ) + ops.append(send_next_op) + if tensor_recv_next is not None: + recv_next_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, + ) + ops.append(recv_next_op) + if len(ops) > 0: +diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py +index c749bac43..dde8d50e7 100644 +--- a/megatron/core/transformer/attention.py ++++ b/megatron/core/transformer/attention.py +@@ -670,7 +670,10 @@ class Attention(MegatronModule, ABC): + # Get the query, key and value tensors based on the type of attention - + # self or cross attn. + nvtx_range_push(suffix="qkv") +- query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) ++ if self.config.use_gated_attention: ++ query, gate, key, value = self.get_query_gate_key_value_tensors(hidden_states, key_value_states) ++ else: ++ query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + nvtx_range_pop(suffix="qkv") + + # =================================================== +@@ -842,6 +845,11 @@ class Attention(MegatronModule, ABC): + # Output. [sq, b, h] + # ================= + ++ if self.config.use_gated_attention: ++ nvtx_range_push(suffix="sigmoid_gate") ++ core_attn_out = core_attn_out * torch.sigmoid(gate) ++ nvtx_range_pop(suffix="sigmoid_gate") ++ + nvtx_range_push(suffix="linear_proj") + output, bias = self.linear_proj(core_attn_out) + nvtx_range_pop(suffix="linear_proj") +@@ -879,19 +887,34 @@ class SelfAttention(Attention): + model_comm_pgs=model_comm_pgs, + ) + +- self.linear_qkv = build_module( +- submodules.linear_qkv, +- self.config.hidden_size, +- self.query_projection_size + 2 * self.kv_projection_size, +- config=self.config, +- init_method=self.config.init_method, +- gather_output=False, +- bias=self.config.add_bias_linear or self.config.add_qkv_bias, +- skip_bias_add=False, +- is_expert=False, +- tp_comm_buffer_name='qkv', +- tp_group=self.model_comm_pgs.tp, +- ) ++ if self.config.use_gated_attention: ++ self.linear_qgkv = build_module( ++ submodules.linear_qkv, ++ self.config.hidden_size, ++ 2 * (self.query_projection_size + self.kv_projection_size), ++ config=self.config, ++ init_method=self.config.init_method, ++ gather_output=False, ++ bias=self.config.add_bias_linear or self.config.add_qkv_bias, ++ skip_bias_add=False, ++ is_expert=False, ++ tp_comm_buffer_name='qkv', ++ tp_group=self.model_comm_pgs.tp, ++ ) ++ else: ++ self.linear_qkv = build_module( ++ submodules.linear_qkv, ++ self.config.hidden_size, ++ self.query_projection_size + 2 * self.kv_projection_size, ++ config=self.config, ++ init_method=self.config.init_method, ++ gather_output=False, ++ bias=self.config.add_bias_linear or self.config.add_qkv_bias, ++ skip_bias_add=False, ++ is_expert=False, ++ tp_comm_buffer_name='qkv', ++ tp_group=self.model_comm_pgs.tp, ++ ) + + if submodules.q_layernorm is not None: + self.q_layernorm = build_module( +@@ -1036,6 +1059,65 @@ class SelfAttention(Attention): + + return query, key, value + ++ # adapt from https://github.com/alibaba/Pai-Megatron-Patch/blob/8e6cbb0556ba09933ab4a4edb23c0af1d19d9960/megatron_patch/model/qwen3_next/gated_attention.py#L192 ++ def get_query_gate_key_value_tensors(self, hidden_states, key_value_states=None): ++ """ ++ Derives `query`, `key` and `value` tensors from `hidden_states`. ++ """ ++ # Attention heads [sq, b, h] --> [sq, b, ng * 2 * (np/ng + 1) * hn)] ++ mixed_qgkv, _ = self.linear_qgkv(hidden_states) ++ ++ # [sq, b, hp] --> [sq, b, ng, 2 * (np/ng + 1) * hn] ++ new_tensor_shape = mixed_qgkv.size()[:-1] + ( ++ self.num_query_groups_per_partition, ++ ( ++ 2 * (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 1) ++ * self.hidden_size_per_attention_head ++ ), ++ ) ++ mixed_qgkv = mixed_qgkv.view(*new_tensor_shape) ++ ++ split_arg_list = [ ++ ( ++ self.num_attention_heads_per_partition ++ // self.num_query_groups_per_partition ++ * self.hidden_size_per_attention_head ++ ), ++ ( ++ self.num_attention_heads_per_partition ++ // self.num_query_groups_per_partition ++ * self.hidden_size_per_attention_head ++ ), ++ self.hidden_size_per_attention_head, ++ self.hidden_size_per_attention_head, ++ ] ++ ++ if SplitAlongDim is not None: ++ ++ # [sq, b, ng, (np/ng + 2) * hn] ++ # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] ++ (query, gate, key, value) = SplitAlongDim(mixed_qgkv, 3, split_arg_list) ++ else: ++ ++ # [sq, b, ng, (np/ng + 2) * hn] ++ # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] ++ (query, gate, key, value) = torch.split(mixed_qgkv, split_arg_list, dim=3) ++ ++ # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] ++ query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) ++ gate = gate.reshape(query.size(0), query.size(1), -1) ++ ++ if self.q_layernorm is not None: ++ query = self.q_layernorm(query) ++ ++ if self.k_layernorm is not None: ++ key = self.k_layernorm(key) ++ ++ if self.config.test_mode: ++ self.run_realtime_tests() ++ ++ return query, gate, key, value ++ + def backward_dw(self) -> NoReturn: + """Execute weight update operations""" + self._backward_qkv_proj() +diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py +index 235b6f6af..fbcffe278 100644 +--- a/megatron/core/transformer/moe/moe_utils.py ++++ b/megatron/core/transformer/moe/moe_utils.py +@@ -566,6 +566,9 @@ def topk_routing_with_score_function( + else: + return torch.topk(scores, k=topk, dim=1) + ++ from miles.utils.routing_replay import get_routing_replay_compute_topk ++ compute_topk = get_routing_replay_compute_topk(compute_topk) ++ + if score_function == "softmax": + if use_pre_softmax: + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) +diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py +index 6b20b8622..459e65921 100644 +--- a/megatron/core/transformer/moe/router.py ++++ b/megatron/core/transformer/moe/router.py +@@ -156,6 +156,9 @@ class TopKRouter(Router): + self.local_tokens_per_expert = None + self.expert_bias = None + ++ from miles.utils.routing_replay import register_routing_replay ++ register_routing_replay(self) ++ + def _maintain_float32_expert_bias(self): + """ + Maintain the expert bias in float32. +diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py +index b7884e18e..f0104f861 100755 +--- a/megatron/core/transformer/multi_token_prediction.py ++++ b/megatron/core/transformer/multi_token_prediction.py +@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union + + import torch + from torch import Tensor ++import warnings + + from megatron.core import InferenceParams, mpu, parallel_state, tensor_parallel + from megatron.core.dist_checkpointing.mapping import ShardedStateDict +@@ -105,17 +106,21 @@ def tie_output_layer_state_dict( + ) + + +-def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): +- """Roll the tensor input along the sequence dimension with Context Parallelism (CP) support. + +- This function extends the original roll_tensor to support Context Parallelism, which allows +- MTP to work with CP > 1. When CP is enabled, the sequence dimension is split across CP ranks, +- and tensor rolling requires communication between adjacent CP ranks to properly handle the +- boundary conditions. ++def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None, packed_seq_params=None): ++ """Roll the tensor input along the sequence dimension with Context Parallelism (CP) and Packed Sequence support. ++ ++ This function extends the original roll_tensor to support Context Parallelism and Packed Sequences. ++ When CP is enabled, the sequence dimension is split across CP ranks, and tensor rolling requires ++ communication between adjacent CP ranks to properly handle the boundary conditions. ++ When packed sequences are used, rolling is performed within each individual sequence boundary ++ to prevent mixing tokens between different packed sequences. + + For CP=1 (default behavior): Uses standard torch.roll with zero padding + For CP>1: Splits tensor into chunks, performs rolling within each chunk, then exchanges + boundary elements between adjacent CP ranks to maintain sequence continuity. ++ For packed sequences: Rolls tensors within sequence boundaries defined by cu_seqlens. ++ + + Args: + tensor (Tensor): The input tensor to roll. +@@ -123,9 +128,15 @@ def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): + dims (int): The dimension to roll (typically -1 for sequence dimension). + cp_group (ProcessGroup): The context parallelism process group. If None or size=1, + falls back to standard rolling behavior. ++ packed_seq_params (PackedSeqParams): Parameters for packed sequence processing. ++ If provided, rolling respects sequence boundaries. + Returns: + tuple: (rolled_tensor, sum_of_rolled_tensor) + """ ++ ++ if packed_seq_params is not None: ++ return _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group) ++ + # Standard rolling behavior when CP is not enabled (cp_group is None or size=1) + if cp_group is None or cp_group.size() == 1: + rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims) +@@ -193,6 +204,103 @@ def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): + + return rolled_tensor, rolled_tensor.sum() + ++def _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group=None): ++ """Roll tensor with packed sequence support. ++ ++ This function handles rolling for packed sequences by respecting sequence boundaries ++ defined in packed_seq_params.cu_seqlens. Rolling is performed within each individual ++ sequence to prevent mixing tokens between different packed sequences. When Context ++ Parallelism (CP) is enabled, each CP rank still receives the full `cu_seqlens` metadata ++ so we slice out the portion of every packed sequence that lives on the current rank and ++ reuse the standard CP boundary exchange to populate the rolling window. ++ ++ Args: ++ tensor (Tensor): The input tensor to roll. ++ shifts (int): The shift of the tensor (typically -1 for MTP). ++ dims (int): The dimension to roll (typically -1 for sequence dimension). ++ packed_seq_params (PackedSeqParams): Parameters for packed sequence processing. ++ cp_group (ProcessGroup): The context parallelism process group. ++ ++ Returns: ++ tuple: (rolled_tensor, sum_of_rolled_tensor) ++ """ ++ ++ # Notice: This is a naive implementation to test the correctness, a better solution will only sync the boundary tokens once. ++ assert dims == -1 or dims == tensor.dim() - 1, "Packed sequence roll only supports the last dimension." ++ assert shifts == -1, "Packed sequence roll only supports a single-token left shift." ++ cu_seqlens = packed_seq_params.cu_seqlens_q ++ assert cu_seqlens is not None, "Packed sequence parameters must provide cu_seqlens_q." ++ ++ rolled_tensor = tensor.clone() ++ ++ cp_size = cp_group.size() if cp_group is not None else 1 ++ if cp_size == 1: ++ # CP disabled: simply roll inside each packed sequence boundary. ++ for i in range(len(cu_seqlens) - 1): ++ start_idx = cu_seqlens[i] ++ end_idx = cu_seqlens[i + 1] ++ seq_slice = tensor[..., start_idx:end_idx] ++ rolled_seq = torch.roll(seq_slice, shifts=shifts, dims=dims) ++ rolled_seq[..., shifts:] = 0 ++ rolled_tensor[..., start_idx:end_idx] = rolled_seq ++ return rolled_tensor, rolled_tensor.sum() ++ ++ # CP enabled: each rank owns two chunks per sequence (front and mirrored tail). ++ local_rank = torch.distributed.get_rank(group=cp_group) ++ global_ranks = torch.distributed.get_process_group_ranks(group=cp_group) ++ next_rank = global_ranks[(local_rank + 1) % cp_size] ++ prev_rank = global_ranks[(local_rank - 1) % cp_size] ++ ++ # iterate over each sequence individually ++ for i in range(len(cu_seqlens) - 1): ++ start_idx = cu_seqlens[i] ++ end_idx = cu_seqlens[i + 1] ++ ++ # the idx has been multiplied by cp_size, so we need to divide it by cp_size to get the local idx ++ local_start_idx = start_idx // cp_size ++ local_end_idx = end_idx // cp_size ++ tensor_slice = rolled_tensor[..., local_start_idx:local_end_idx].clone() ++ ++ # The following code is very similar as the code in roll_tensor function ++ local_chunks = tensor_slice.chunk(2, dim=dims) ++ rolled_chunks = [ ++ torch.roll(chunk, shifts=shifts, dims=dims) for chunk in local_chunks ++ ] ++ ++ tensor_send_list = [] ++ tensor_recv_list = [] ++ for chunk in rolled_chunks: ++ boundary = chunk.select(dims, shifts).contiguous().clone() ++ tensor_send_list.append(boundary) ++ tensor_recv_list.append(torch.empty_like(boundary)) ++ ++ ops = [] ++ if local_rank != 0: ++ ops.append(torch.distributed.isend(tensor=tensor_send_list[0], dst=prev_rank)) ++ ops.append(torch.distributed.irecv(tensor=tensor_recv_list[1], src=prev_rank)) ++ else: ++ tensor_recv_list[1].zero_() ++ ++ if local_rank != cp_size - 1: ++ ops.append(torch.distributed.irecv(tensor=tensor_recv_list[0], src=next_rank)) ++ ops.append(torch.distributed.isend(tensor=tensor_send_list[1], dst=next_rank)) ++ else: ++ tensor_recv_list[0].copy_(tensor_send_list[1]) ++ ++ for op in ops: ++ op.wait() ++ ++ index = [slice(None)] * rolled_chunks[0].dim() ++ index[dims] = shifts ++ for chunk, recv in zip(rolled_chunks, tensor_recv_list): ++ chunk[tuple(index)] = recv ++ ++ seq_result = torch.cat(rolled_chunks, dim=dims) ++ ++ # update the rolled tensor ++ rolled_tensor[..., local_start_idx:local_end_idx] = seq_result ++ ++ return rolled_tensor, rolled_tensor.sum() + + class MTPLossLoggingHelper: + """Helper class for logging MTP losses.""" +@@ -480,9 +588,10 @@ class MultiTokenPredictionLayer(MegatronModule): + def _get_embeddings( + self, + input_ids: torch.Tensor, +- position_ids: torch.Tensor, + embedding: Callable, + hidden_states: torch.Tensor, ++ position_ids: Optional[torch.Tensor] = None, ++ packed_seq_params: Optional[PackedSeqParams] = None, + ): + """ + Preprocesses input data for the Multi-Token Prediction (MTP) layers. +@@ -499,12 +608,23 @@ class MultiTokenPredictionLayer(MegatronModule): + sequence length, b is the batch size, and h is the hidden size. + """ + # Calc logits for the current Multi-Token Prediction (MTP) layers. +- input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1, cp_group=self.cp_group) +- position_ids, _ = roll_tensor(position_ids, shifts=-1, dims=-1, cp_group=self.cp_group) ++ input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) ++ ++ # Prepare/roll position ids only when applicable. ++ if position_ids is None: ++ # Fallback position ids for learned absolute embedding. ++ seq_len = input_ids.size(-1) ++ position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device) ++ position_ids = position_ids.unsqueeze(0).expand_as(input_ids) ++ ++ position_ids, _ = roll_tensor( ++ position_ids, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params ++ ) + # embedding + decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) ++ decoder_input = decoder_input.detach() + +- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) ++ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) + + return input_ids, position_ids, decoder_input, hidden_states + +@@ -604,22 +724,66 @@ class MultiTokenPredictionLayer(MegatronModule): + return hidden_states + + def _checkpointed_forward(self, forward_func, *args, **kwargs): ++ """Wrap `forward_func` with activation checkpointing while only passing tensors. ++ ++ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so ++ that checkpoint implementations never receive them directly, avoiding save_for_backward ++ issues with non-tensor inputs. ++ """ ++ ++ # TODO(jiajun): Is there any better implementation here? ++ positional_specs = [] ++ kw_specs = [] ++ tensor_args: List[torch.Tensor] = [] ++ ++ for arg in args: ++ if torch.is_tensor(arg): ++ positional_specs.append(('tensor', len(tensor_args))) ++ tensor_args.append(arg) ++ else: ++ positional_specs.append(('const', arg)) ++ ++ for key, value in kwargs.items(): ++ if torch.is_tensor(value): ++ kw_specs.append((key, ('tensor', len(tensor_args)))) ++ tensor_args.append(value) ++ else: ++ kw_specs.append((key, ('const', value))) ++ ++ def run(*flat_tensor_args): ++ rebuilt_args = [] ++ for spec_type, payload in positional_specs: ++ if spec_type == 'tensor': ++ rebuilt_args.append(flat_tensor_args[payload]) ++ else: ++ rebuilt_args.append(payload) ++ ++ rebuilt_kwargs = {} ++ for key, (spec_type, payload) in kw_specs: ++ if spec_type == 'tensor': ++ rebuilt_kwargs[key] = flat_tensor_args[payload] ++ else: ++ rebuilt_kwargs[key] = payload ++ ++ return forward_func(*rebuilt_args, **rebuilt_kwargs) ++ ++ tensor_args_tuple = tuple(tensor_args) ++ + def checkpoint_handler(): +- """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" ++ """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`.""" + if self.config.fp8: + from megatron.core.extensions.transformer_engine import te_checkpoint + + return te_checkpoint( +- forward_func, ++ run, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), +- *args, +- **kwargs, ++ *tensor_args_tuple, + ) + else: + return tensor_parallel.checkpoint( +- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() ++ run, self.config.distribute_saved_activations, *tensor_args_tuple + ) + + if self.config.recompute_method == 'uniform': +@@ -681,15 +845,13 @@ class MultiTokenPredictionLayer(MegatronModule): + [s, b, h], and optionally the updated context tensor if cross-attention is used. + """ + assert context is None, f"multi token prediction + cross attention is not yet supported." +- assert ( +- packed_seq_params is None +- ), f"multi token prediction + sequence packing is not yet supported." + + input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings( + input_ids=input_ids, + position_ids=position_ids, + embedding=embedding, + hidden_states=hidden_states, ++ packed_seq_params=packed_seq_params, + ) + + if self.config.recompute_granularity == 'full' and self.training: +diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py +index d55bebe7e..1eecbbd38 100644 +--- a/megatron/core/transformer/transformer_config.py ++++ b/megatron/core/transformer/transformer_config.py +@@ -173,6 +173,10 @@ class TransformerConfig(ModelParallelConfig): + qk_layernorm: bool = False + """Whether to apply `normalization` type of normalization to the query and key embeddings.""" + ++ post_self_attn_layernorm: bool = False ++ post_mlp_layernorm: bool = False ++ use_gated_attention: bool = False ++ + test_mode: bool = False + """Whether to run real-time tests.""" + +diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py +index 84f22bdea..f0f3f8e86 100644 +--- a/megatron/core/transformer/transformer_layer.py ++++ b/megatron/core/transformer/transformer_layer.py +@@ -224,6 +224,7 @@ class TransformerLayerSubmodules: + input_layernorm: Union[ModuleSpec, type] = IdentityOp + self_attention: Union[ModuleSpec, type] = IdentityOp + self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + + pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + cross_attention: Union[ModuleSpec, type] = IdentityOp +@@ -232,6 +233,7 @@ class TransformerLayerSubmodules: + pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + mlp: Union[ModuleSpec, type] = IdentityOp + mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + + # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method + sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) +@@ -336,6 +338,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): + # [Module 3: BiasDropoutFusion] + self.self_attn_bda = build_module(submodules.self_attn_bda) + ++ self.post_self_attn_layernorm = build_module( ++ submodules.post_self_attn_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon, ++ ) ++ + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn + self.pre_cross_attn_layernorm = build_module( + submodules.pre_cross_attn_layernorm, +@@ -399,6 +408,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): + # [Module 9: BiasDropoutFusion] + self.mlp_bda = build_module(submodules.mlp_bda) + ++ self.post_mlp_layernorm = build_module( ++ submodules.post_mlp_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon ++ ) ++ + self.recompute_input_layernorm = False + self.recompute_pre_mlp_layernorm = False + self.recompute_mlp = False +@@ -535,6 +551,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): + attention_output_with_bias[0] + ) + ++ attention_output, attention_output_bias = attention_output_with_bias ++ attention_output = self.post_self_attn_layernorm(attention_output) ++ attention_output_with_bias = (attention_output, attention_output_bias) ++ + # TODO: could we move `bias_dropout_add_exec_handler` itself + # inside the module provided in the `bias_dropout_add_spec` module? + nvtx_range_push(suffix="self_attn_bda") +@@ -635,6 +655,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): + else: + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + ++ mlp_output, mlp_output_bias = mlp_output_with_bias ++ mlp_output = self.post_mlp_layernorm(mlp_output) ++ mlp_output_with_bias = (mlp_output, mlp_output_bias) ++ + if self.recompute_pre_mlp_layernorm: + # discard the output of the pre-mlp layernorm and register the recompute + # as a gradient hook of mlp_output_with_bias[0] +diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py +index e3459c5ee..7346bf35b 100644 +--- a/megatron/training/arguments.py ++++ b/megatron/training/arguments.py +@@ -937,8 +937,6 @@ def validate_args(args, defaults={}): + # MoE Spec check + if args.num_experts == 0: + args.num_experts = None +- if args.num_experts is not None: +- assert args.spec is None, "Model Spec must be None when using MoEs" + if args.num_experts is not None and args.moe_ffn_hidden_size is None: + args.moe_ffn_hidden_size = args.ffn_hidden_size + print("Warning: moe_ffn_hidden_size is not set, using ffn_hidden_size for MoE instead.") +@@ -1198,6 +1196,10 @@ def core_transformer_config_from_args(args, config_class=None): + if args.is_hybrid_model: + kw_args['is_hybrid_model'] = args.is_hybrid_model + ++ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm ++ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm ++ kw_args['use_gated_attention'] = args.use_gated_attention ++ + # handle quantization config + # NOTE: Kitchen arguments are only added to the namespace when + # Kitchen library is available. +@@ -1488,6 +1490,12 @@ def _add_network_size_args(parser): + action='store_true', + help='If set, use original BERT residula connection ' + 'ordering.') ++ group.add_argument('--post-self-attn-layernorm', action='store_true', ++ help='If set, use post self attention layernorm.') ++ group.add_argument('--post-mlp-layernorm', action='store_true', ++ help='If set, use post MLP layernorm.') ++ group.add_argument('--use-gated-attention', action='store_true', ++ help='If set, use gated attention as in Qwen3Next') + group.add_argument('--openai-gelu', action='store_true', + help='Use OpenAIs GeLU implementation. This option' + 'should not be used unless for backward compatibility' +diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py +index 5cf222ccc..d1554ca4c 100644 +--- a/megatron/training/tokenizer/tokenizer.py ++++ b/megatron/training/tokenizer/tokenizer.py +@@ -138,6 +138,8 @@ class _HuggingFaceTokenizer(MegatronTokenizer): + f"The transformers library must be installed to use huggingface_tokenizer_provider" + ) + ++ if "trust_remote_code" not in kwargs: ++ kwargs["trust_remote_code"] = True + # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there + self._tokenizer = transformers.AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs diff --git a/docker/patch/v0.5.6/sglang.patch b/docker/patch/v0.5.6/sglang.patch new file mode 100644 index 000000000..de12cdd43 --- /dev/null +++ b/docker/patch/v0.5.6/sglang.patch @@ -0,0 +1,2053 @@ +diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py +index ef52bda7f..537d892dc 100644 +--- a/python/sglang/srt/disaggregation/decode.py ++++ b/python/sglang/srt/disaggregation/decode.py +@@ -296,6 +296,13 @@ class DecodePreallocQueue: + ) + return kv_manager + ++ def release_memory_occupation(self): ++ if hasattr(self.kv_manager, "close"): ++ self.kv_manager.close() ++ ++ def resume_memory_occupation(self): ++ self.kv_manager = self._init_kv_manager() ++ + def add(self, req: Req, is_retracted: bool = False) -> None: + """Add a request to the pending queue.""" + if self._check_if_req_exceed_kv_capacity(req): +diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py +index d4414d084..c5fb10155 100644 +--- a/python/sglang/srt/disaggregation/mooncake/conn.py ++++ b/python/sglang/srt/disaggregation/mooncake/conn.py +@@ -1074,6 +1074,19 @@ class MooncakeKVManager(CommonKVManager): + f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), {len(affected_rooms)} requests affected" + ) + ++ def close(self): ++ # Batch deregister KV data buffers ++ if self.kv_args.kv_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.kv_data_ptrs) ++ ++ # Batch deregister auxiliary data buffers ++ if self.kv_args.aux_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.aux_data_ptrs) ++ ++ # Batch deregister state/extra pool data buffers ++ if self.kv_args.state_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.state_data_ptrs) ++ + + class MooncakeKVSender(CommonKVSender): + +diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py +index 952374ed5..239ac2571 100644 +--- a/python/sglang/srt/disaggregation/prefill.py ++++ b/python/sglang/srt/disaggregation/prefill.py +@@ -305,6 +305,13 @@ class PrefillBootstrapQueue: + else: + return bootstrapped_reqs, failed_reqs + ++ def release_memory_occupation(self): ++ if hasattr(self.kv_manager, "close"): ++ self.kv_manager.close() ++ ++ def resume_memory_occupation(self): ++ self.kv_manager = self._init_kv_manager() ++ + + class SchedulerDisaggregationPrefillMixin: + """ +diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py +index 86c53f26b..52acf95b9 100644 +--- a/python/sglang/srt/distributed/device_communicators/pynccl.py ++++ b/python/sglang/srt/distributed/device_communicators/pynccl.py +@@ -380,3 +380,9 @@ class PyNcclCommunicator: + + self.disabled = old_disable + self.stream = old_stream ++ ++ def nccl_pause(self): ++ self.nccl.ncclPause(self.comm) ++ ++ def nccl_resume(self): ++ self.nccl.ncclResume(self.comm) +diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py +index 6b12f2922..7028a4e46 100644 +--- a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py ++++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py +@@ -304,6 +304,17 @@ class NCCLLibrary: + Function("ncclGroupEnd", ncclResult_t, []), + ] + ++ if os.environ.get("AMEM_ENABLE", "0") == "1": ++ exported_functions.extend( ++ [ ++ # ncclResult_t ncclPause(ncclComm_t comm); ++ Function("ncclPause", ncclResult_t, [ncclComm_t]), ++ # ncclResult_t ncclResume(ncclComm_t comm); ++ Function("ncclResume", ncclResult_t, [ncclComm_t]), ++ Function("ncclSetGroupID", ncclResult_t, [ctypes.c_int]), ++ ] ++ ) ++ + exported_functions_symm_mem = [ + # ncclResult_t ncclCommWindowRegister(ncclComm_t comm, void* buff, size_t size, ncclWindow_t* win, int winFlags); + Function( +@@ -551,6 +562,12 @@ class NCCLLibrary: + def ncclGroupEnd(self) -> None: + self.NCCL_CHECK(self._funcs["ncclGroupEnd"]()) + ++ def ncclPause(self, comm: ncclComm_t) -> None: ++ self.NCCL_CHECK(self._funcs["ncclPause"](comm)) ++ ++ def ncclResume(self, comm: ncclComm_t) -> None: ++ self.NCCL_CHECK(self._funcs["ncclResume"](comm)) ++ + + __all__ = [ + "NCCLLibrary", +diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py +index cf90f6fe0..11d26df81 100644 +--- a/python/sglang/srt/distributed/parallel_state.py ++++ b/python/sglang/srt/distributed/parallel_state.py +@@ -1780,7 +1780,10 @@ def get_tensor_model_parallel_world_size(): + + def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" +- return get_tp_group().rank_in_group ++ try: ++ return get_tp_group().rank_in_group ++ except Exception: ++ return 0 + + + def get_pipeline_model_parallel_world_size(): +diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py +index 67a082ea6..390365864 100644 +--- a/python/sglang/srt/entrypoints/engine.py ++++ b/python/sglang/srt/entrypoints/engine.py +@@ -183,6 +183,7 @@ class Engine(EngineBase): + lora_path: Optional[List[Optional[str]]] = None, + custom_logit_processor: Optional[Union[List[str], str]] = None, + return_hidden_states: bool = False, ++ return_routed_experts: bool = False, + stream: bool = False, + bootstrap_host: Optional[Union[List[str], str]] = None, + bootstrap_port: Optional[Union[List[int], int]] = None, +@@ -218,6 +219,7 @@ class Engine(EngineBase): + lora_path=lora_path, + custom_logit_processor=custom_logit_processor, + return_hidden_states=return_hidden_states, ++ return_routed_experts=return_routed_experts, + stream=stream, + bootstrap_host=bootstrap_host, + bootstrap_port=bootstrap_port, +diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py +index 9f556a885..992843285 100644 +--- a/python/sglang/srt/layers/attention/vision.py ++++ b/python/sglang/srt/layers/attention/vision.py +@@ -518,11 +518,25 @@ class VisionAttention(nn.Module): + self.dummy_dim = (num_dummy_heads + num_heads) * self.head_size + + if self.qk_normalization: ++ norm_kwargs = ( ++ dict( ++ weight_dtype=torch.float32, ++ cast_x_before_out_mul=True, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) + self.q_norm = RMSNorm( +- self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim ++ self.dummy_dim, ++ eps=layer_norm_eps, ++ var_hidden_size=embed_dim, ++ **norm_kwargs, + ) + self.k_norm = RMSNorm( +- self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim ++ self.dummy_dim, ++ eps=layer_norm_eps, ++ var_hidden_size=embed_dim, ++ **norm_kwargs, + ) + + # Select attention backend via a unified method +@@ -648,6 +662,15 @@ class VisionAttention(nn.Module): + if x.dim() == 2: + x = x.unsqueeze(0) + assert x.dim() == 3, x.shape ++ if ( ++ get_global_server_args().rl_on_policy_target is not None ++ and position_embeddings is not None ++ ): ++ assert isinstance(position_embeddings, tuple), ( ++ "expected position_embeddings to be a tuple of two tensors,\n" ++ f"but got {type(position_embeddings)}, change if needed" ++ ) ++ position_embeddings = tuple(p.to(x.dtype) for p in position_embeddings) + x_shape = x.shape + bsz, s, _ = x_shape + head = self.num_attention_heads_per_partition +diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py +index 932f52aeb..79c6b664f 100644 +--- a/python/sglang/srt/layers/communicator.py ++++ b/python/sglang/srt/layers/communicator.py +@@ -372,6 +372,7 @@ class LayerCommunicator: + residual: torch.Tensor, + forward_batch: ForwardBatch, + quant_format: str = "", ++ post_residual_addition: Optional[torch.Tensor] = None, + ): + if get_attn_tp_context().input_scattered: + hidden_states, residual = self._tp_reduce_scatter( +@@ -453,7 +454,9 @@ class LayerCommunicator: + ) + else: + hidden_states, residual = self.input_layernorm( +- hidden_states, residual ++ hidden_states, ++ residual, ++ post_residual_addition, + ) + + hidden_states = self._communicate_simple_fn( +diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py +index 3293a8a59..a075b71ce 100644 +--- a/python/sglang/srt/layers/layernorm.py ++++ b/python/sglang/srt/layers/layernorm.py +@@ -84,15 +84,12 @@ class RMSNorm(CustomOp): + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + cast_x_before_out_mul: bool = False, +- fp32_residual: bool = False, +- weight_dtype: Optional = None, +- override_orig_dtype: Optional = None, ++ fp32_residual: bool = True, + ) -> None: + super().__init__() + self.cast_x_before_out_mul = cast_x_before_out_mul + self.fp32_residual = fp32_residual +- self.override_orig_dtype = override_orig_dtype +- self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype)) ++ self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.hidden_size = hidden_size + self.variance_size_override = ( +@@ -105,21 +102,26 @@ class RMSNorm(CustomOp): + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, ++ post_residual_addition: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if self.variance_size_override is not None: +- return self.forward_native(x, residual) ++ return self.forward_native(x, residual, post_residual_addition) + if is_batch_invariant_mode_enabled(): + if ( + residual is not None + or get_global_server_args().rl_on_policy_target == "fsdp" + ): +- return self.forward_native(x, residual) ++ return self.forward_native(x, residual, post_residual_addition) + return rms_norm_batch_invariant( + x, + self.weight.data, + self.variance_epsilon, + ) + if residual is not None: ++ # TODO: Ideally we want to have (a+b)+c. but right now we can only have a+(b+c). ++ # (a+b)+c != a+(b+c), we probably need to add another parameter to fused_add_rmsnorm ++ if post_residual_addition is not None: ++ residual = residual + post_residual_addition + fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) + return x, residual + out = rmsnorm(x, self.weight.data, self.variance_epsilon) +@@ -179,17 +181,35 @@ class RMSNorm(CustomOp): + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, ++ post_residual_addition: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if not x.is_contiguous(): + x = x.contiguous() +- orig_dtype = self.override_orig_dtype or x.dtype ++ orig_dtype = x.dtype ++ ++ if residual is not None and not self.fp32_residual: ++ x = ( ++ x ++ + residual ++ + ( ++ post_residual_addition ++ if post_residual_addition is not None ++ else 0.0 ++ ) ++ ) ++ residual = x.clone() + x = x.to(torch.float32) +- if residual is not None: +- x = x + residual.to(torch.float32) +- if self.fp32_residual: +- residual = x.clone() +- else: +- residual = x.to(orig_dtype) ++ if residual is not None and self.fp32_residual: ++ x = ( ++ x ++ + residual.to(torch.float32) ++ + ( ++ post_residual_addition.to(torch.float32) ++ if post_residual_addition is not None ++ else 0.0 ++ ) ++ ) ++ residual = x.to(orig_dtype) + + hidden_size = x.shape[-1] + if hidden_size != self.hidden_size: +diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py +index 522865765..733bad5f2 100644 +--- a/python/sglang/srt/layers/logits_processor.py ++++ b/python/sglang/srt/layers/logits_processor.py +@@ -841,11 +841,6 @@ class LogitsProcessor(nn.Module): + None, # bias + True, # is_vnni + ) +- elif get_global_server_args().rl_on_policy_target is not None: +- # Due to tie-weight, we may not be able to change lm_head's weight dtype +- logits = torch.matmul( +- hidden_states.bfloat16(), lm_head.weight.T.bfloat16() +- ) + else: + logits = torch.matmul( + hidden_states.to(lm_head.weight.dtype), lm_head.weight.T +diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +index e7d5a67cc..639e47163 100644 +--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py ++++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +@@ -14,6 +14,7 @@ import torch.nn.functional as F + import triton.language as tl + + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig ++from sglang.srt.server_args import get_global_server_args + from sglang.srt.utils import ( + cpu_has_amx_support, + direct_register_custom_op, +@@ -626,7 +627,10 @@ def fused_experts_impl( + ).squeeze(dim=1) + else: + # According to micro benchmark results, torch.compile can get better performance for small token. +- if tokens_in_chunk <= 32: ++ if ( ++ not get_global_server_args().enable_deterministic_inference ++ and tokens_in_chunk <= 32 ++ ): + moe_sum_reduce_torch_compile( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], +diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py +new file mode 100644 +index 000000000..e16817f1f +--- /dev/null ++++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py +@@ -0,0 +1,279 @@ ++import logging ++from abc import ABC ++from contextlib import contextmanager ++from typing import Optional ++ ++import numpy as np ++import torch ++ ++from sglang.srt.configs.model_config import ModelConfig ++from sglang.srt.layers.dp_attention import ( ++ get_attention_dp_rank, ++ get_dp_local_info, ++ is_dp_attention_enabled, ++) ++from sglang.srt.mem_cache.memory_pool import ReqToTokenPool ++from sglang.srt.server_args import get_global_server_args ++ ++logger = logging.getLogger(__name__) ++ ++_GB = 1024 * 1024 * 1024 ++_MB = 1024 * 1024 ++ ++ ++def get_tensor_size_bytes(t: torch.Tensor): ++ return np.prod(t.shape) * t.dtype.itemsize ++ ++ ++class _RoutedExpertsDeviceCache: ++ def __init__( ++ self, ++ max_running_requests: int, ++ num_hidden_layers: int, ++ num_experts_per_tok: int, ++ num_fused_shared_experts: int, ++ device: str, ++ ) -> None: ++ self.buffer = torch.zeros( ++ ( ++ max( ++ get_global_server_args().chunked_prefill_size ++ * get_global_server_args().dp_size, ++ max_running_requests, ++ ), ++ num_hidden_layers, ++ num_experts_per_tok + num_fused_shared_experts, ++ ), ++ dtype=torch.int32, ++ device=device, ++ ) ++ self._finalize_allocation_log() ++ ++ def get_buffer_size_bytes(self): ++ assert hasattr(self, "buffer") ++ return get_tensor_size_bytes(self.buffer) ++ ++ def capture_fwd_routed_experts(self, layer_id: int, topk_ids: torch.Tensor): ++ assert layer_id is not None, "capturing routing experts but get layer_id None" ++ batch, _ = topk_ids.shape ++ self.buffer[:batch, layer_id, :] = topk_ids ++ ++ def _finalize_allocation_log(self): ++ """Common logging and memory usage computation for captured experts buffers.""" ++ buffer_size_MB = self.get_buffer_size_bytes() / _MB ++ logger.info( ++ f"Routing experts device buffer allocated. #shape: {tuple(self.buffer.shape)}, size: {buffer_size_MB:.2f} MB" ++ ) ++ ++ ++class _RoutedExpertsHostCache: ++ def __init__( ++ self, ++ num_tokens: int, ++ num_hidden_layers: int, ++ num_experts_per_tok: int, ++ ) -> None: ++ self.num_tokens = num_tokens ++ self.buffer = torch.zeros( ++ ( ++ num_tokens, ++ num_hidden_layers, ++ num_experts_per_tok, ++ ), ++ dtype=torch.int32, ++ device="cpu", ++ pin_memory=True, ++ ) ++ self._finalize_allocation_log() ++ ++ def get_buffer_size_bytes(self): ++ assert hasattr(self, "buffer") ++ return get_tensor_size_bytes(self.buffer) ++ ++ def set_experts_buffer(self, layer_id: int, loc: torch.Tensor, top_k: torch.Tensor): ++ self.buffer[layer_id, loc, :] = top_k.to(device="cpu", non_blocking=True) ++ ++ def _finalize_allocation_log(self): ++ """Common logging and memory usage computation for captured experts buffers.""" ++ buffer_size_GB = self.get_buffer_size_bytes() / _GB ++ logger.info( ++ f"Routing experts host buffer allocated. #tokens: {self.num_tokens}, size: {buffer_size_GB:.2f} GB" ++ ) ++ ++ ++class RoutedExpertsCapturer(ABC): ++ @staticmethod ++ def create( ++ enable: bool, ++ model_config: ModelConfig, ++ num_fused_shared_experts: int, ++ num_tokens: int, ++ max_running_requests: int, ++ device: str, ++ ): ++ if enable: ++ return _RoutedExpertsCapturerReal( ++ model_config, ++ num_tokens=num_tokens, ++ max_running_requests=max_running_requests, ++ num_fused_shared_experts=num_fused_shared_experts, ++ device=device, ++ ) ++ else: ++ return _RoutedExpertsCapturerNoop() ++ ++ def capture(self, layer_id: int, topk_ids: torch.Tensor): ++ raise NotImplementedError ++ ++ def get_routed_experts( ++ self, ++ req_pool_idx: int, ++ seqlen: int, ++ req_to_token_pool: ReqToTokenPool, ++ ): ++ raise NotImplementedError ++ ++ def sync_fwd_experts_buffer_DtoH( ++ self, ++ device_loc: torch.Tensor, ++ cpu_loc: torch.Tensor, ++ can_run_graph: bool, ++ cuda_graph_batch: int, ++ ): ++ raise NotImplementedError ++ ++ @contextmanager ++ def with_forward(self, forward_batch): ++ yield ++ ++ def get_host_cache(self): ++ raise NotImplementedError ++ ++ def get_device_cache(self): ++ raise NotImplementedError ++ ++ ++class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): ++ """Capturer for routed experts with host buffer""" ++ ++ def __init__( ++ self, ++ model_config: ModelConfig, ++ num_tokens: int, ++ max_running_requests: int, ++ num_fused_shared_experts: int, ++ device: str, ++ ): ++ self.forward_batch = None ++ self.num_fused_shared_experts = num_fused_shared_experts ++ self.num_hidden_layers = model_config.hf_text_config.num_hidden_layers ++ self.num_experts_per_tok = model_config.hf_text_config.num_experts_per_tok ++ ++ self.host_cache = _RoutedExpertsHostCache( ++ num_tokens=num_tokens, ++ num_hidden_layers=self.num_hidden_layers, ++ num_experts_per_tok=self.num_experts_per_tok, ++ ) ++ ++ self.device_cache = _RoutedExpertsDeviceCache( ++ max_running_requests=max_running_requests, ++ num_hidden_layers=self.num_hidden_layers, ++ num_experts_per_tok=self.num_experts_per_tok, ++ num_fused_shared_experts=self.num_fused_shared_experts, ++ device=device, ++ ) ++ ++ def capture(self, layer_id: int, topk_ids: torch.Tensor): ++ self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) ++ ++ def sync_fwd_experts_buffer_DtoH( ++ self, ++ device_loc: torch.Tensor, ++ cpu_loc: torch.Tensor, ++ can_run_graph: bool, ++ cuda_graph_batch: int, ++ ): ++ if is_dp_attention_enabled(): ++ local_start_pos, local_num_tokens = get_dp_local_info(self.forward_batch) ++ # handle with cuda graph padding ++ if can_run_graph: ++ local_start_pos = get_attention_dp_rank() * cuda_graph_batch ++ local_end_pos = local_start_pos + local_num_tokens ++ else: ++ local_end_pos = local_start_pos + local_num_tokens ++ else: ++ local_start_pos = 0 ++ local_end_pos = device_loc.shape[0] ++ ++ self.host_cache.buffer[cpu_loc] = self.device_cache.buffer[ ++ local_start_pos:local_end_pos, :, : self.num_experts_per_tok ++ ].cpu() ++ ++ def get_routed_experts( ++ self, ++ req_pool_idx: int, ++ seqlen: int, ++ req_to_token_pool: ReqToTokenPool, ++ ): ++ cache_pool_idx = ( ++ req_to_token_pool.req_to_token[req_pool_idx][: seqlen - 1].cpu().clone() ++ ) ++ return self.get_host_cache().buffer[cache_pool_idx] ++ ++ @contextmanager ++ def with_forward(self, forward_batch): ++ self.forward_batch = forward_batch ++ yield ++ ++ def get_host_cache(self): ++ return self.host_cache ++ ++ def get_device_cache(self): ++ return self.device_cache ++ ++ ++class _RoutedExpertsCapturerNoop(RoutedExpertsCapturer): ++ def __init__(self): ++ pass ++ ++ def capture(self, layer_id: int, topk_ids: torch.Tensor): ++ pass ++ ++ def get_routed_experts( ++ self, ++ req_pool_idx: int, ++ seqlen: int, ++ req_to_token_pool: ReqToTokenPool, ++ ): ++ pass ++ ++ def sync_fwd_experts_buffer_DtoH( ++ self, ++ device_loc: torch.Tensor, ++ cpu_loc: torch.Tensor, ++ can_run_graph: bool, ++ cuda_graph_batch: int, ++ ): ++ pass ++ ++ @contextmanager ++ def with_forward(self, forward_batch): ++ yield ++ ++ def get_host_cache(self): ++ pass ++ ++ def get_device_cache(self): ++ pass ++ ++ ++_global_expert_capturer: Optional[RoutedExpertsCapturer] = _RoutedExpertsCapturerNoop() ++ ++ ++def get_global_experts_capturer(): ++ return _global_expert_capturer ++ ++ ++def set_global_experts_capturer(capturer: RoutedExpertsCapturer): ++ global _global_expert_capturer ++ _global_expert_capturer = capturer +\ No newline at end of file +diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py +index a802647e8..0fd550c0c 100644 +--- a/python/sglang/srt/layers/moe/topk.py ++++ b/python/sglang/srt/layers/moe/topk.py +@@ -48,6 +48,7 @@ from sglang.srt.eplb.expert_location_dispatch import ( + ) + from sglang.srt.layers.dp_attention import is_allocation_symmetric + from sglang.srt.layers.moe import get_moe_runner_backend ++from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer + from sglang.srt.utils import ( + cpu_has_amx_support, + get_bool_env_var, +@@ -212,6 +213,7 @@ class TopK(CustomOp): + self, + top_k: int, + *, ++ layer_id: Optional[int] = None, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, +@@ -233,6 +235,7 @@ class TopK(CustomOp): + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + ++ self.layer_id = layer_id + self.topk_config = TopKConfig( + top_k=top_k, + use_grouped_topk=use_grouped_topk, +@@ -260,6 +263,7 @@ class TopK(CustomOp): + self.topk_config.torch_native = True + return select_experts( + hidden_states=hidden_states, ++ layer_id=self.layer_id, + router_logits=router_logits, + topk_config=self.topk_config, + num_token_non_padded=num_token_non_padded, +@@ -309,6 +313,7 @@ class TopK(CustomOp): + ): + topk_output = select_experts( + hidden_states=hidden_states, ++ layer_id=self.layer_id, + router_logits=router_logits, + topk_config=self.topk_config, + num_token_non_padded=num_token_non_padded, +@@ -326,6 +331,7 @@ class TopK(CustomOp): + ) -> TopKOutput: + return select_experts( + hidden_states=hidden_states, ++ layer_id=self.layer_id, + router_logits=router_logits, + topk_config=self.topk_config, + num_token_non_padded=num_token_non_padded, +@@ -856,6 +862,7 @@ def select_experts( + router_logits: torch.Tensor, + topk_config: TopKConfig, + *, ++ layer_id: Optional[int] = None, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + ) -> StandardTopKOutput: +@@ -983,7 +990,10 @@ def select_experts( + ) + + get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids) +- ++ get_global_experts_capturer().capture( ++ layer_id=layer_id, ++ topk_ids=topk_ids, ++ ) + return StandardTopKOutput(topk_weights, topk_ids, router_logits) + + +diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py +index 70466bb20..cd85fc2f2 100644 +--- a/python/sglang/srt/layers/moe/utils.py ++++ b/python/sglang/srt/layers/moe/utils.py +@@ -284,7 +284,7 @@ def speculative_moe_a2a_backend_context(): + global MOE_A2A_BACKEND + original_backend = MOE_A2A_BACKEND + try: +- MOE_A2A_BACKEND = MoeA2ABackend.NONE ++ MOE_A2A_BACKEND = get_speculative_moe_a2a_backend() + yield + finally: + MOE_A2A_BACKEND = original_backend +diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py +index 0cdb7e1ae..df8860409 100644 +--- a/python/sglang/srt/layers/rotary_embedding.py ++++ b/python/sglang/srt/layers/rotary_embedding.py +@@ -15,7 +15,6 @@ from sglang.srt.server_args import get_global_server_args + from sglang.srt.utils import ( + cpu_has_amx_support, + get_bool_env_var, +- get_compiler_backend, + is_cpu, + is_cuda, + is_hip, +@@ -132,9 +131,7 @@ class RotaryEmbedding(CustomOp): + + if get_global_server_args().rl_on_policy_target is not None: + self._forward_method = self.forward_native +- self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)( +- self._apply_rotary_emb_wrapped +- ) ++ + self.position_cos, self.position_sin = None, None + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: +@@ -1423,6 +1420,9 @@ class MRotaryEmbedding(RotaryEmbedding): + f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})" + ) + ++ if get_global_server_args().rl_on_policy_target is not None: ++ self._forward_method = self.forward_native ++ + def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None: + # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) + # is expensive, so avoid calling it if possible +@@ -1432,8 +1432,7 @@ class MRotaryEmbedding(RotaryEmbedding): + ): + self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) + +- @torch.compile(dynamic=True, backend=get_compiler_backend()) +- def _forward_native( ++ def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, +@@ -1490,7 +1489,7 @@ class MRotaryEmbedding(RotaryEmbedding): + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + +- def forward( ++ def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, +@@ -1507,14 +1506,12 @@ class MRotaryEmbedding(RotaryEmbedding): + """ + assert positions.ndim == 1 or positions.ndim == 2 + +- if positions.ndim == 2 and self.mrope_section and _is_cuda: +- return self._forward_triton(positions, query, key) +- elif _is_npu: +- return self._forward_npu(positions, query, key) +- else: +- return self._forward_native(positions, query, key) ++ # Use Triton kernel for multimodal (2D positions) with mrope ++ if positions.ndim == 2 and self.mrope_section: ++ return self.forward_triton(positions, query, key) ++ return self.forward_native(positions, query, key, fused_set_kv_buffer_arg) + +- def _forward_triton( ++ def forward_triton( + self, + positions: torch.Tensor, + query: torch.Tensor, +@@ -1563,15 +1560,19 @@ class MRotaryEmbedding(RotaryEmbedding): + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + +- def _forward_npu( ++ def forward_npu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, ++ fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: ++ assert ( ++ fused_set_kv_buffer_arg is None ++ ), "fused_set_kv_buffer_arg is not supported for npu implementation" + # TODO: remove this when npu_mrope supports QNumHeads * QHeadSize > 4096 + if query.shape[1] > 4096: +- return self._forward_native(positions, query, key) ++ return self.forward_native(positions, query, key, fused_set_kv_buffer_arg) + rotary_mode = "half" + if self.is_neox_style: + rotary_mode = "half" +diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py +index 7f6f6a010..c4a673145 100644 +--- a/python/sglang/srt/layers/sampler.py ++++ b/python/sglang/srt/layers/sampler.py +@@ -105,16 +105,11 @@ class Sampler(nn.Module): + if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB: + probs_without_temp_scaling = torch.softmax(logits, dim=-1) + +- if get_global_server_args().rl_on_policy_target is not None: +- logits_div_temperature = ( +- logits.bfloat16().div(sampling_info.temperatures).bfloat16() +- ) +- logprobs_via_logsoftmax_kernel = torch.log_softmax( +- logits_div_temperature, dim=-1 +- ) +- + # Post process logits + logits.div_(sampling_info.temperatures) ++ if get_global_server_args().rl_on_policy_target is not None: ++ logprobs_via_logsoftmax_kernel = torch.log_softmax(logits, dim=-1) ++ + # For ascend backend, softmax is not needed before sampling + if not get_global_server_args().sampling_backend == "ascend" or ( + return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB +diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py +index 87922077e..8cb6bad8d 100644 +--- a/python/sglang/srt/managers/detokenizer_manager.py ++++ b/python/sglang/srt/managers/detokenizer_manager.py +@@ -247,6 +247,16 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): + s.sent_offset = len(output_str) + output_strs.append(incremental_output) + ++ output_routed_experts = [] ++ if recv_obj.output_routed_experts is not None: ++ output_routed_experts = [ ++ ( ++ output_routed_experts.tolist() ++ if output_routed_experts is not None ++ else [] ++ ) ++ for output_routed_experts in recv_obj.output_routed_experts ++ ] + return BatchStrOutput( + rids=recv_obj.rids, + http_worker_ipcs=recv_obj.http_worker_ipcs, +@@ -272,6 +282,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): + output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx, + output_token_entropy_val=recv_obj.output_token_entropy_val, + output_hidden_states=recv_obj.output_hidden_states, ++ output_routed_experts=output_routed_experts, + placeholder_tokens_idx=None, + placeholder_tokens_val=None, + retraction_counts=recv_obj.retraction_counts, +diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py +index e34736cc4..5e5997a1a 100644 +--- a/python/sglang/srt/managers/io_struct.py ++++ b/python/sglang/srt/managers/io_struct.py +@@ -23,6 +23,8 @@ from dataclasses import dataclass, field + from enum import Enum + from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union + ++import torch ++ + from sglang.srt.lora.lora_registry import LoRARef + from sglang.srt.managers.schedule_batch import BaseFinishReason + from sglang.srt.multimodal.mm_utils import has_valid_data +@@ -175,6 +177,8 @@ class GenerateReqInput(BaseReq): + log_metrics: bool = True + # Whether to return hidden states + return_hidden_states: Union[List[bool], bool] = False ++ # Whether to return captured routed experts ++ return_routed_experts: bool = False + + # The modalities of the image data [image, multi-images, video] + modalities: Optional[List[str]] = None +@@ -592,6 +596,7 @@ class GenerateReqInput(BaseReq): + if isinstance(self.return_hidden_states, list) + else self.return_hidden_states + ), ++ return_routed_experts=self.return_routed_experts, + modalities=self.modalities[i] if self.modalities else None, + session_params=self.session_params, + lora_path=self.lora_path[i] if self.lora_path is not None else None, +@@ -655,6 +660,9 @@ class TokenizedGenerateReqInput(BaseReq): + # Whether to return hidden states + return_hidden_states: bool = False + ++ # Whether to return captured routed experts ++ return_routed_experts: bool = False ++ + # The input embeds + input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None + +@@ -910,6 +918,9 @@ class BatchTokenIDOutput( + # Hidden states + output_hidden_states: List[List[float]] + ++ # The routed experts for each output token ++ output_routed_experts: List[torch.Tensor] ++ + # The information of placeholder tokens (e.g., image token) + # idx is the index of the token in the prompt after expansion. + # val is the length of padded tokens after expansion. +@@ -989,6 +1000,9 @@ class BatchStrOutput( + # Hidden states + output_hidden_states: List[List[float]] + ++ # The routed experts for each output token ++ output_routed_experts: List[List[int]] ++ + # The information of placeholder tokens (e.g., image token) + # idx is the index of the token in the prompt after expansion. + # val is the length of padded tokens after expansion. +diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py +index c4c5a9ebb..1450c5fd8 100644 +--- a/python/sglang/srt/managers/schedule_batch.py ++++ b/python/sglang/srt/managers/schedule_batch.py +@@ -450,6 +450,7 @@ class Req: + session_id: Optional[str] = None, + custom_logit_processor: Optional[str] = None, + return_hidden_states: bool = False, ++ return_routed_experts: bool = False, + eos_token_ids: Optional[Set[int]] = None, + bootstrap_host: Optional[str] = None, + bootstrap_port: Optional[int] = None, +@@ -629,6 +630,12 @@ class Req: + self.output_topk_p = None + self.output_topk_index = None + ++ # capture routed experts ++ self.return_routed_experts = return_routed_experts ++ self.routed_experts: Optional[torch.Tensor] = ( ++ None # cpu tensor: shape (seqlen, topk) ++ ) ++ + # Embedding (return values) + self.embedding = None + +@@ -992,6 +999,7 @@ class Req: + self.retraction_count += 1 + + self.prefix_indices = torch.empty((0,), dtype=torch.int64) ++ self.routed_experts = [] + self.last_node = None + self.swa_uuid_for_lock = None + self.extend_input_len = 0 +@@ -1159,6 +1167,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + # Whether to return hidden states + return_hidden_states: bool = False + ++ # Whether to return captured experts ++ return_routed_experts: bool = False ++ + # Whether this batch is prefill-only (no token generation needed) + is_prefill_only: bool = False + +@@ -1206,6 +1217,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + device=req_to_token_pool.device, + spec_algorithm=spec_algorithm, + return_hidden_states=any(req.return_hidden_states for req in reqs), ++ return_routed_experts=any(req.return_routed_experts for req in reqs), + is_prefill_only=all(req.is_prefill_only for req in reqs), + chunked_req=chunked_req, + dllm_config=dllm_config, +@@ -1457,6 +1469,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + self.req_pool_indices = req_pool_indices_tensor + self.orig_seq_lens = orig_seq_lens_tensor + self.out_cache_loc = out_cache_loc ++ self.out_cache_loc_cpu = out_cache_loc.cpu() + self.input_embeds = ( + torch.tensor(input_embeds).to(self.device, non_blocking=True) + if input_embeds +@@ -1508,10 +1521,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + + input_ids = torch.cat([self.input_ids, running_batch.input_ids]) + out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc]) ++ out_cache_loc_cpu = torch.cat( ++ [self.out_cache_loc_cpu, running_batch.out_cache_loc_cpu] ++ ) + + self.merge_batch(running_batch) + self.input_ids = input_ids + self.out_cache_loc = out_cache_loc ++ self.out_cache_loc_cpu = out_cache_loc_cpu + + # For overlap scheduler, the output_ids has one step delay + delta = 0 if self.enable_overlap else -1 +@@ -1677,6 +1694,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + self.seq_lens_cpu = torch.empty(0, dtype=torch.int64) + self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device) + self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device) ++ self.out_cache_loc_cpu = torch.empty(0, dtype=torch.int64, device="cpu") + self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device) + self.seq_lens_sum = 0 + self.extend_num_tokens = 0 +@@ -1736,6 +1754,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + + # Allocate memory + self.out_cache_loc = alloc_for_decode(self, token_per_req=1) ++ self.out_cache_loc_cpu = self.out_cache_loc.to("cpu", non_blocking=True) + + # Update req-level memory management fields + for req in self.reqs: +@@ -1807,6 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + self.seq_lens_cpu = self.seq_lens_cpu[keep_indices] + self.orig_seq_lens = self.orig_seq_lens[keep_indices_device] + self.out_cache_loc = None ++ self.out_cache_loc_cpu = None + self.seq_lens_sum = self.seq_lens.sum().item() + self.output_ids = self.output_ids[keep_indices_device] + self.return_logprob = any(req.return_logprob for req in self.reqs) +@@ -1852,6 +1872,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu]) + self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens]) + self.out_cache_loc = None ++ self.out_cache_loc_cpu = None + self.seq_lens_sum += other.seq_lens_sum + if self.output_ids is not None: + self.output_ids = torch.cat([self.output_ids, other.output_ids]) +@@ -1903,6 +1924,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + seq_lens=self.seq_lens, + orig_seq_lens=self.orig_seq_lens, + out_cache_loc=self.out_cache_loc, ++ out_cache_loc_cpu=self.out_cache_loc_cpu, + seq_lens_cpu=seq_lens_cpu, + seq_lens_sum=self.seq_lens_sum, + return_logprob=self.return_logprob, +@@ -1983,7 +2005,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + def __str__(self): + return ( + f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, " +- f"#req={(len(self.reqs))})" ++ f"#req={(len(self.reqs))}), " ++ f"#out_cache_loc={self.out_cache_loc})" + ) + + +@@ -2038,6 +2061,9 @@ class ModelWorkerBatch: + # Sampling info + sampling_info: SamplingBatchInfo + ++ # cpu copy of out_cache_loc ++ out_cache_loc_cpu: Optional[torch.Tensor] = None ++ + # The original sequence lengths, Qwen-1M related + orig_seq_lens: Optional[torch.Tensor] = None + +diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py +index b801fd8f8..9e27cc825 100644 +--- a/python/sglang/srt/managers/scheduler.py ++++ b/python/sglang/srt/managers/scheduler.py +@@ -1305,6 +1305,7 @@ class Scheduler( + input_embeds=recv_req.input_embeds, + custom_logit_processor=recv_req.custom_logit_processor, + return_hidden_states=recv_req.return_hidden_states, ++ return_routed_experts=recv_req.return_routed_experts, + eos_token_ids=self.model_config.hf_eos_token_id, + bootstrap_host=recv_req.bootstrap_host, + bootstrap_port=recv_req.bootstrap_port, +diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py +index c48f5f893..a9796c25f 100644 +--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py ++++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py +@@ -9,6 +9,7 @@ import torch + from sglang.srt.disaggregation.utils import DisaggregationMode + from sglang.srt.environ import envs + from sglang.srt.layers.logits_processor import LogitsProcessorOutput ++from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer + from sglang.srt.managers.io_struct import ( + AbortReq, + BatchEmbeddingOutput, +@@ -112,6 +113,14 @@ class SchedulerOutputProcessorMixin: + req.check_finished() + + if req.finished(): ++ req.routed_experts = ( ++ get_global_experts_capturer().get_routed_experts( ++ req_pool_idx=req.req_pool_idx, ++ seqlen=req.seqlen, ++ req_to_token_pool=self.req_to_token_pool, ++ ) ++ ) ++ + release_kv_cache(req, self.tree_cache) + req.time_stats.completion_time = time.perf_counter() + elif not batch.decoding_reqs or req not in batch.decoding_reqs: +@@ -362,6 +371,12 @@ class SchedulerOutputProcessorMixin: + req.check_finished(new_accepted_len) + + if req.finished(): ++ req.routed_experts = get_global_experts_capturer().get_routed_experts( ++ req_pool_idx=req.req_pool_idx, ++ seqlen=req.seqlen, ++ req_to_token_pool=self.req_to_token_pool, ++ ) ++ + if self.server_args.disaggregation_decode_enable_offload_kvcache: + # Asynchronously offload KV cache; release_kv_cache will be called after Device->Host transfer completes + if not self.decode_offload_manager.offload_kv_cache(req): +@@ -756,6 +771,7 @@ class SchedulerOutputProcessorMixin: + spec_accepted_tokens = [] + retraction_counts = [] + output_hidden_states = None ++ output_routed_experts = None + + queue_times = [] + forward_entry_times = [] +@@ -946,6 +962,10 @@ class SchedulerOutputProcessorMixin: + if output_hidden_states is None: + output_hidden_states = [] + output_hidden_states.append(req.hidden_states) ++ if req.return_routed_experts: ++ if output_routed_experts is None: ++ output_routed_experts = [] ++ output_routed_experts.append(req.routed_experts) + + if ( + req.finished() +@@ -994,6 +1014,7 @@ class SchedulerOutputProcessorMixin: + output_token_ids_logprobs_idx=output_token_ids_logprobs_idx, + output_token_entropy_val=None, + output_hidden_states=output_hidden_states, ++ output_routed_experts=output_routed_experts, + placeholder_tokens_idx=None, + placeholder_tokens_val=None, + retraction_counts=retraction_counts, +diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py +index f8ebfc1f4..a05449fac 100644 +--- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py ++++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py +@@ -1,6 +1,7 @@ + from __future__ import annotations + + import logging ++import os + import traceback + from typing import TYPE_CHECKING, Tuple + +@@ -12,6 +13,9 @@ from sglang.srt.constants import ( + GPU_MEMORY_TYPE_KV_CACHE, + GPU_MEMORY_TYPE_WEIGHTS, + ) ++from sglang.srt.disaggregation.utils import DisaggregationMode ++from sglang.srt.distributed import get_moe_ep_group, get_moe_tp_group, get_tp_group ++from sglang.srt.layers.dp_attention import get_attention_tp_group + from sglang.srt.managers.io_struct import ( + CheckWeightsReqInput, + CheckWeightsReqOutput, +@@ -127,6 +131,13 @@ class SchedulerUpdateWeightsMixin: + self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE) + self.flush_cache() + ++ if self.disaggregation_mode == DisaggregationMode.DECODE: ++ if hasattr(self, "disagg_decode_prealloc_queue"): ++ self.disagg_decode_prealloc_queue.release_memory_occupation() ++ elif self.disaggregation_mode == DisaggregationMode.PREFILL: ++ if hasattr(self, "disagg_prefill_bootstrap_queue"): ++ self.disagg_prefill_bootstrap_queue.release_memory_occupation() ++ + if GPU_MEMORY_TYPE_WEIGHTS in tags: + self.stashed_model_static_state = _export_static_state( + self.tp_worker.model_runner.model +@@ -137,6 +148,20 @@ class SchedulerUpdateWeightsMixin: + if GPU_MEMORY_TYPE_CUDA_GRAPH in tags: + self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_CUDA_GRAPH) + ++ if os.environ.get("AMEM_ENABLE", "0") == "1": ++ tp_group = get_tp_group() ++ if tp_group is not None and tp_group.pynccl_comm is not None: ++ tp_group.pynccl_comm.nccl_pause() ++ attn_tp_group = get_attention_tp_group() ++ if attn_tp_group is not None and attn_tp_group.pynccl_comm is not None: ++ attn_tp_group.pynccl_comm.nccl_pause() ++ moe_ep_group = get_moe_ep_group() ++ if moe_ep_group is not None and moe_ep_group.pynccl_comm is not None: ++ moe_ep_group.pynccl_comm.nccl_pause() ++ moe_tp_group = get_moe_tp_group() ++ if moe_tp_group is not None and moe_tp_group.pynccl_comm is not None: ++ moe_tp_group.pynccl_comm.nccl_pause() ++ + torch.get_device_module().synchronize() + + return ReleaseMemoryOccupationReqOutput() +@@ -155,6 +180,20 @@ class SchedulerUpdateWeightsMixin: + if GPU_MEMORY_TYPE_CUDA_GRAPH in tags: + self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_CUDA_GRAPH) + ++ if os.environ.get("AMEM_ENABLE", "0") == "1": ++ tp_group = get_tp_group() ++ if tp_group is not None and tp_group.pynccl_comm is not None: ++ tp_group.pynccl_comm.nccl_resume() ++ attn_tp_group = get_attention_tp_group() ++ if attn_tp_group is not None and attn_tp_group.pynccl_comm is not None: ++ attn_tp_group.pynccl_comm.nccl_resume() ++ moe_ep_group = get_moe_ep_group() ++ if moe_ep_group is not None and moe_ep_group.pynccl_comm is not None: ++ moe_ep_group.pynccl_comm.nccl_resume() ++ moe_tp_group = get_moe_tp_group() ++ if moe_tp_group is not None and moe_tp_group.pynccl_comm is not None: ++ moe_tp_group.pynccl_comm.nccl_resume() ++ + if GPU_MEMORY_TYPE_WEIGHTS in tags: + self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS) + torch.distributed.barrier(self.tp_cpu_group) +@@ -167,6 +206,13 @@ class SchedulerUpdateWeightsMixin: + if GPU_MEMORY_TYPE_KV_CACHE in tags: + self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE) + ++ if self.disaggregation_mode == DisaggregationMode.DECODE: ++ if hasattr(self, "disagg_decode_prealloc_queue"): ++ self.disagg_decode_prealloc_queue.resume_memory_occupation() ++ elif self.disaggregation_mode == DisaggregationMode.PREFILL: ++ if hasattr(self, "disagg_prefill_bootstrap_queue"): ++ self.disagg_prefill_bootstrap_queue.resume_memory_occupation() ++ + return ResumeMemoryOccupationReqOutput() + + def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): +diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py +index b90cf0616..98d71d896 100644 +--- a/python/sglang/srt/managers/tokenizer_manager.py ++++ b/python/sglang/srt/managers/tokenizer_manager.py +@@ -888,6 +888,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): + session_params=session_params, + custom_logit_processor=obj.custom_logit_processor, + return_hidden_states=obj.return_hidden_states, ++ return_routed_experts=obj.return_routed_experts, + data_parallel_rank=obj.data_parallel_rank, + priority=obj.priority, + extra_key=obj.extra_key, +@@ -1621,6 +1622,9 @@ class TokenizerManager(TokenizerCommunicatorMixin): + if getattr(recv_obj, "output_hidden_states", None): + meta_info["hidden_states"] = recv_obj.output_hidden_states[i] + ++ if getattr(recv_obj, "output_routed_experts", None): ++ meta_info["routed_experts"] = recv_obj.output_routed_experts[i] ++ + if isinstance(recv_obj, BatchStrOutput): + state.text += recv_obj.output_strs[i] + if self.server_args.stream_output and state.obj.stream: +@@ -1747,12 +1751,13 @@ class TokenizerManager(TokenizerCommunicatorMixin): + return + + if len(recv_obj.input_token_logprobs_val) > 0: +- state.input_token_logprobs_val.extend( +- recv_obj.input_token_logprobs_val[recv_obj_index] +- ) +- state.input_token_logprobs_idx.extend( +- recv_obj.input_token_logprobs_idx[recv_obj_index] +- ) ++ if recv_obj.input_token_logprobs_val[recv_obj_index]: ++ state.input_token_logprobs_val.extend( ++ recv_obj.input_token_logprobs_val[recv_obj_index] ++ ) ++ state.input_token_logprobs_idx.extend( ++ recv_obj.input_token_logprobs_idx[recv_obj_index] ++ ) + state.output_token_logprobs_val.extend( + recv_obj.output_token_logprobs_val[recv_obj_index] + ) +diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py +index 3a85e6a7e..2859dafa1 100644 +--- a/python/sglang/srt/model_executor/forward_batch_info.py ++++ b/python/sglang/srt/model_executor/forward_batch_info.py +@@ -51,6 +51,7 @@ from sglang.srt.layers.dp_attention import ( + set_dp_buffer_len, + set_is_extend_in_batch, + ) ++from sglang.srt.server_args import get_global_server_args + from sglang.srt.utils import get_compiler_backend, is_npu, support_triton + from sglang.srt.utils.common import ceil_align + +@@ -214,6 +215,9 @@ class ForwardBatch: + # The sum of all sequence lengths + seq_lens_sum: int + ++ # cpu copy of out_cache_loc ++ out_cache_loc_cpu: Optional[torch.Tensor] = None ++ + # The original sequence length without being chunked. Qwen-1M related. + orig_seq_lens: Optional[torch.Tensor] = None + +@@ -368,6 +372,7 @@ class ForwardBatch: + req_pool_indices=batch.req_pool_indices, + seq_lens=batch.seq_lens, + out_cache_loc=batch.out_cache_loc, ++ out_cache_loc_cpu=batch.out_cache_loc_cpu, + mm_inputs=batch.multimodal_inputs, + encoder_cached=batch.encoder_cached, + encoder_lens=batch.encoder_lens, +@@ -623,7 +628,10 @@ class ForwardBatch: + mm_input = batch.multimodal_inputs[batch_idx] + if self.forward_mode.is_decode(): + # 3 * N +- if mm_input is None: ++ if ( ++ mm_input is None ++ or get_global_server_args().rl_on_policy_target is not None ++ ): + mrope_positions_list[batch_idx] = torch.full( + (3, 1), + self.seq_lens[batch_idx] - 1, +@@ -640,7 +648,10 @@ class ForwardBatch: + batch.extend_seq_lens[batch_idx], + batch.extend_prefix_lens[batch_idx], + ) +- if mm_input is None: ++ if ( ++ mm_input is None ++ or get_global_server_args().rl_on_policy_target is not None ++ ): + # text only + mrope_positions = torch.tensor( + [ +@@ -823,6 +834,10 @@ class ForwardBatch: + ) + + self.out_cache_loc = self._pad_tensor_to_size(self.out_cache_loc, num_tokens) ++ if self.out_cache_loc_cpu is not None: ++ self.out_cache_loc_cpu = self._pad_tensor_to_size( ++ self.out_cache_loc_cpu, num_tokens ++ ) + if self.encoder_lens is not None: + self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs) + self.positions = self._pad_tensor_to_size(self.positions, num_tokens) +diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py +index 4d58278b7..8f50dc430 100644 +--- a/python/sglang/srt/model_executor/model_runner.py ++++ b/python/sglang/srt/model_executor/model_runner.py +@@ -94,6 +94,11 @@ from sglang.srt.layers.dp_attention import ( + set_is_extend_in_batch, + ) + from sglang.srt.layers.logits_processor import LogitsProcessorOutput ++from sglang.srt.layers.moe.routed_experts_capturer import ( ++ RoutedExpertsCapturer, ++ get_global_experts_capturer, ++ set_global_experts_capturer, ++) + from sglang.srt.layers.pooler import EmbeddingPoolerOutput + from sglang.srt.layers.sampler import Sampler + from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model +@@ -502,6 +507,10 @@ class ModelRunner: + server_args.max_running_requests, + server_args.max_total_tokens, + ) ++ ++ # Init routed experts capturer ++ self.init_routed_experts_capturer() ++ + if self.device == "cuda": + self.init_cublas() + self.init_attention_backend() +@@ -545,6 +554,40 @@ class ModelRunner: + # Initialize piecewise CUDA graph + self.init_piecewise_cuda_graphs() + ++ def init_routed_experts_capturer(self): ++ # TODO: the redundant logic with TpModelWorker ++ max_running_requests = min( ++ ( ++ self.max_total_num_tokens // 2 ++ if self.server_args.max_running_requests is None ++ else self.server_args.max_running_requests ++ // ( ++ self.server_args.dp_size ++ if self.server_args.enable_dp_attention ++ else 1 ++ ) ++ ), ++ self.req_to_token_pool.size, ++ ) ++ ++ if not self.server_args.disable_shared_experts_fusion and hasattr( ++ self.model, "num_fused_shared_experts" ++ ): ++ num_fused_shared_experts = self.model.num_fused_shared_experts ++ else: ++ num_fused_shared_experts = 0 ++ ++ set_global_experts_capturer( ++ RoutedExpertsCapturer.create( ++ enable=get_global_server_args().enable_return_routed_experts, ++ model_config=self.model_config, ++ num_fused_shared_experts=num_fused_shared_experts, ++ num_tokens=self.max_total_num_tokens + self.page_size, ++ max_running_requests=max_running_requests, ++ device=self.device, ++ ) ++ ) ++ + def model_specific_adjustment(self): + server_args = self.server_args + +@@ -792,7 +835,11 @@ class ModelRunner: + ) + with self.memory_saver_adapter.region( + GPU_MEMORY_TYPE_WEIGHTS, +- enable_cpu_backup=enable_cpu_backup, ++ enable_cpu_backup=( ++ self.server_args.enable_weights_cpu_backup ++ if not self.is_draft_worker ++ else True ++ ), + ): + self.model = get_model( + model_config=self.model_config, +@@ -2645,9 +2692,12 @@ class ModelRunner: + ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]: + self.forward_pass_id += 1 + +- with get_global_expert_distribution_recorder().with_forward_pass( +- self.forward_pass_id, +- forward_batch, ++ with ( ++ get_global_expert_distribution_recorder().with_forward_pass( ++ self.forward_pass_id, ++ forward_batch, ++ ), ++ get_global_experts_capturer().with_forward(forward_batch), + ): + output = self._forward_raw( + forward_batch, +@@ -2656,6 +2706,13 @@ class ModelRunner: + reinit_attn_backend, + split_forward_count, + ) ++ # Copy cached routing experts' buffers back to CPU cache ++ get_global_experts_capturer().sync_fwd_experts_buffer_DtoH( ++ device_loc=forward_batch.out_cache_loc, ++ cpu_loc=forward_batch.out_cache_loc_cpu, ++ can_run_graph=output[1], ++ cuda_graph_batch=getattr(self.graph_runner, "bs", None), ++ ) + + if self.eplb_manager is not None: + self.eplb_manager.on_forward_pass_end() +diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py +index dc30b4f0a..f29dc4b71 100644 +--- a/python/sglang/srt/models/deepseek_v2.py ++++ b/python/sglang/srt/models/deepseek_v2.py +@@ -667,6 +667,7 @@ class DeepseekV2MoE(nn.Module): + + self.topk = TopK( + top_k=config.num_experts_per_tok + self.num_fused_shared_experts, ++ layer_id=self.layer_id, + renormalize=config.norm_topk_prob, + use_grouped_topk=True, + num_expert_group=config.n_group, +diff --git a/python/sglang/srt/models/ernie4.py b/python/sglang/srt/models/ernie4.py +index ab1b6576b..dffd8f09a 100644 +--- a/python/sglang/srt/models/ernie4.py ++++ b/python/sglang/srt/models/ernie4.py +@@ -87,6 +87,7 @@ class Ernie4Moe(nn.Module): + + self.topk = TopK( + top_k=config.moe_k, ++ layer_id=layer_id, + renormalize=True, + use_grouped_topk=False, + correction_bias=self.gate.e_score_correction_bias, +diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py +index a9689b8f2..bc8538da8 100644 +--- a/python/sglang/srt/models/glm4_moe.py ++++ b/python/sglang/srt/models/glm4_moe.py +@@ -379,6 +379,17 @@ class Glm4MoeSparseMoeBlock(nn.Module): + + self.gate = Glm4MoeGate(config=config, prefix=add_prefix("gate", prefix)) + ++ self.topk = TopK( ++ top_k=self.top_k, ++ layer_id=self.layer_id, ++ renormalize=config.norm_topk_prob, ++ use_grouped_topk=True, ++ num_expert_group=config.n_group, ++ topk_group=config.topk_group, ++ correction_bias=self.gate.e_score_correction_bias, ++ routed_scaling_factor=self.routed_scaling_factor, ++ ) ++ + self.experts = get_moe_impl_class(quant_config)( + num_experts=config.n_routed_experts + self.num_fused_shared_experts, + num_fused_shared_experts=self.num_fused_shared_experts, +diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py +index 9474700c4..398d622ff 100644 +--- a/python/sglang/srt/models/gpt_oss.py ++++ b/python/sglang/srt/models/gpt_oss.py +@@ -113,6 +113,7 @@ class GptOssSparseMoeBlock(nn.Module): + self.topk = TopK( + top_k=config.num_experts_per_tok, + renormalize=True, ++ layer_id=layer_id, + ) + + self.top_k = config.num_experts_per_tok +diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py +index fd513060a..a089475b7 100644 +--- a/python/sglang/srt/models/grok.py ++++ b/python/sglang/srt/models/grok.py +@@ -142,6 +142,7 @@ class Grok1MoE(nn.Module): + self.topk = TopK( + top_k=top_k, + renormalize=False, ++ layer_id=layer_id, + custom_routing_function=custom_routing_function, + ) + +diff --git a/python/sglang/srt/models/hunyuan.py b/python/sglang/srt/models/hunyuan.py +index 7c6fd9e48..b20d28544 100644 +--- a/python/sglang/srt/models/hunyuan.py ++++ b/python/sglang/srt/models/hunyuan.py +@@ -150,6 +150,7 @@ class HunYuanSparseMoeBlock(nn.Module): + + self.topk = TopK( + top_k=top_k, ++ layer_id=layer_id, + renormalize=True if top_k > 1 else False, + ) + +diff --git a/python/sglang/srt/models/longcat_flash.py b/python/sglang/srt/models/longcat_flash.py +index 3530609ba..01c89e893 100644 +--- a/python/sglang/srt/models/longcat_flash.py ++++ b/python/sglang/srt/models/longcat_flash.py +@@ -245,6 +245,7 @@ class LongcatFlashMoE(nn.Module): + renormalize=False, + use_grouped_topk=False, + correction_bias=self.router.e_score_correction_bias.data, ++ layer_id=layer_id, + ) + self.topk.forward = self.topk.forward_native + +diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py +index a7dbadec6..c83a41338 100644 +--- a/python/sglang/srt/models/qwen2.py ++++ b/python/sglang/srt/models/qwen2.py +@@ -90,9 +90,6 @@ class Qwen2MLP(nn.Module): + self.act_fn = SiluAndMul() + + def forward(self, x): +- if get_global_server_args().rl_on_policy_target is not None: +- x = x.bfloat16() +- + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) +@@ -279,11 +276,6 @@ class Qwen2Model(nn.Module): + quant_config=quant_config, + enable_tp=not is_dp_attention_enabled(), + prefix=add_prefix("embed_tokens", prefix), +- params_dtype=( +- torch.float32 +- if get_global_server_args().rl_on_policy_target is not None +- else None +- ), + ) + else: + self.embed_tokens = PPMissingLayer() +@@ -306,10 +298,8 @@ class Qwen2Model(nn.Module): + if self.pp_group.is_last_rank: + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, +- override_orig_dtype=torch.float32, +- fp32_residual=True, ++ fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} +diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py +index ea33e81ef..561934dce 100644 +--- a/python/sglang/srt/models/qwen2_moe.py ++++ b/python/sglang/srt/models/qwen2_moe.py +@@ -161,6 +161,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): + self.topk = TopK( + top_k=config.num_experts_per_tok, + renormalize=config.norm_topk_prob, ++ layer_id=layer_id, + ) + + self.experts = get_moe_impl_class(quant_config)( +@@ -581,7 +582,17 @@ class Qwen2MoeModel(nn.Module): + prefix=add_prefix("layers", prefix), + ) + if self.pp_group.is_last_rank: +- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.norm = RMSNorm( ++ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs ++ ) + else: + self.norm = PPMissingLayer(return_tuple=True) + +diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py +index 30b92acbd..a0d14895f 100644 +--- a/python/sglang/srt/models/qwen3.py ++++ b/python/sglang/srt/models/qwen3.py +@@ -90,8 +90,8 @@ class Qwen3Attention(nn.Module): + + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, ++ fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} +@@ -256,10 +256,8 @@ class Qwen3DecoderLayer(nn.Module): + + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, +- override_orig_dtype=torch.float32, +- fp32_residual=True, ++ fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} +@@ -289,10 +287,14 @@ class Qwen3DecoderLayer(nn.Module): + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], ++ post_residual_addition: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + hidden_states, residual = self.layer_communicator.prepare_attn( +- hidden_states, residual, forward_batch ++ hidden_states, ++ residual, ++ forward_batch, ++ post_residual_addition=post_residual_addition, + ) + if hidden_states.shape[0] != 0: + hidden_states = self.self_attn( +diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py +index 9737ac719..09c756918 100644 +--- a/python/sglang/srt/models/qwen3_moe.py ++++ b/python/sglang/srt/models/qwen3_moe.py +@@ -22,6 +22,7 @@ import math + from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar + + import torch ++import torch.nn.functional as F + from torch import nn + from transformers import PretrainedConfig + +@@ -50,7 +51,7 @@ from sglang.srt.layers.moe import ( + ) + from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class + from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +-from sglang.srt.layers.moe.topk import TopK ++from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK + from sglang.srt.layers.moe.utils import RoutingMethodType + from sglang.srt.layers.quantization.base_config import QuantizationConfig + from sglang.srt.layers.radix_attention import RadixAttention +@@ -227,7 +228,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module): + top_k=config.num_experts_per_tok, + renormalize=config.norm_topk_prob, + use_grouped_topk=False, ++ layer_id=layer_id, + ) ++ self.top_k = config.num_experts_per_tok + + self.experts = get_moe_impl_class(quant_config)( + num_experts=config.num_experts +@@ -293,7 +296,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module): + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) +- topk_output = self.topk(hidden_states, router_logits) ++ ++ if get_global_server_args().rl_on_policy_target is not None: ++ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) ++ routing_weights, selected_experts = torch.topk( ++ routing_weights, self.top_k, dim=-1 ++ ) ++ routing_weights /= routing_weights.sum(dim=-1, keepdim=True) ++ routing_weights = routing_weights.to(hidden_states.dtype) ++ topk_output = StandardTopKOutput( ++ topk_weights=routing_weights, ++ topk_ids=selected_experts, ++ router_logits=router_logits, ++ ) ++ else: ++ topk_output = self.topk(hidden_states, router_logits) ++ + final_hidden_states = self.experts(hidden_states, topk_output) + if ( + self.tp_size > 1 +@@ -474,13 +492,14 @@ class Qwen3MoeAttention(nn.Module): + ) + self.compatible_with_fused_kv_buffer = ( + False if isinstance(self.rotary_emb, MRotaryEmbedding) else True +- ) ++ ) and (get_global_server_args().rl_on_policy_target is None) + self.compatible_with_fused_qk_norm_rope = ( + not isinstance(self.rotary_emb, MRotaryEmbedding) + ) and self.head_dim in (64, 128, 256) + self.use_fused_qk_norm_rope = ( + get_global_server_args().enable_fused_qk_norm_rope + and self.compatible_with_fused_qk_norm_rope ++ and (get_global_server_args().rl_on_policy_target is None) + ) + self._used_fused_qk_norm_rope_last_call = False + +@@ -493,8 +512,16 @@ class Qwen3MoeAttention(nn.Module): + prefix=add_prefix("attn", prefix), + ) + +- self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) +- self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) ++ self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) + self.alt_stream = alt_stream + + def _apply_qk_norm( +@@ -751,9 +778,19 @@ class Qwen3MoeDecoderLayer(nn.Module): + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) +- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.input_layernorm = RMSNorm( ++ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs ++ ) + self.post_attention_layernorm = RMSNorm( +- config.hidden_size, eps=config.rms_norm_eps ++ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs + ) + + self.layer_communicator = LayerCommunicator( +diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py +index ed52f7ff4..8ce9fab9d 100644 +--- a/python/sglang/srt/models/qwen3_vl.py ++++ b/python/sglang/srt/models/qwen3_vl.py +@@ -18,7 +18,6 @@ import re + from functools import lru_cache, partial + from typing import Callable, Iterable, List, Optional, Tuple, Union + +-import numpy as np + import torch + import torch.nn as nn + from einops import rearrange +@@ -349,83 +348,65 @@ class Qwen3VLMoeVisionModel(nn.Module, RotaryPosMixin): + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): ++ grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + num_grid_per_side = int(self.num_position_embeddings**0.5) ++ device = self.pos_embed.weight.device + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + +- # TODO: use torch instand of np +- for t, h, w in grid_thw: +- h_idxs = np.linspace(0, num_grid_per_side - 1, h) +- w_idxs = np.linspace(0, num_grid_per_side - 1, w) ++ for t, h, w in zip(grid_ts, grid_hs, grid_ws): ++ h_idxs = torch.linspace(0, num_grid_per_side - 1, h) ++ w_idxs = torch.linspace(0, num_grid_per_side - 1, w) + +- h_idxs_floor = h_idxs.astype(int) +- w_idxs_floor = w_idxs.astype(int) +- h_idxs_ceil = (h_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1) +- w_idxs_ceil = (w_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1) ++ h_idxs_floor = h_idxs.int() ++ w_idxs_floor = w_idxs.int() ++ h_idxs_ceil = (h_idxs.int() + 1).clip(max=num_grid_per_side - 1) ++ w_idxs_ceil = (w_idxs.int() + 1).clip(max=num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + +- idx_list[0].extend( +- ((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_floor[None]) +- .flatten() +- .tolist() +- * t +- ) +- idx_list[1].extend( +- ((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_ceil[None]) +- .flatten() +- .tolist() +- * t +- ) +- idx_list[2].extend( +- ((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_floor[None]) +- .flatten() +- .tolist() +- * t +- ) +- idx_list[3].extend( +- ((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_ceil[None]) +- .flatten() +- .tolist() +- * t +- ) ++ base_h = h_idxs_floor * num_grid_per_side ++ base_h_ceil = h_idxs_ceil * num_grid_per_side + +- weight_list[0].extend( +- ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t +- ) +- weight_list[1].extend(((1 - dh)[None].T * dw[None]).flatten().tolist() * t) +- weight_list[2].extend((dh[None].T * (1 - dw)[None]).flatten().tolist() * t) +- weight_list[3].extend((dh[None].T * dw[None]).flatten().tolist() * t) ++ indices = [ ++ (base_h[None].T + w_idxs_floor[None]).flatten(), ++ (base_h[None].T + w_idxs_ceil[None]).flatten(), ++ (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), ++ (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), ++ ] + +- device = self.pos_embed.weight.device +- dtype = self.pos_embed.weight.dtype ++ weights = [ ++ ((1 - dh)[None].T * (1 - dw)[None]).flatten(), ++ ((1 - dh)[None].T * dw[None]).flatten(), ++ (dh[None].T * (1 - dw)[None]).flatten(), ++ (dh[None].T * dw[None]).flatten(), ++ ] + +- p0 = ( +- self.pos_embed(torch.tensor(idx_list[0], dtype=torch.long, device=device)) +- * torch.tensor(weight_list[0], dtype=dtype, device=device)[:, None] +- ) +- p1 = ( +- self.pos_embed(torch.tensor(idx_list[1], dtype=torch.long, device=device)) +- * torch.tensor(weight_list[1], dtype=dtype, device=device)[:, None] +- ) +- p2 = ( +- self.pos_embed(torch.tensor(idx_list[2], dtype=torch.long, device=device)) +- * torch.tensor(weight_list[2], dtype=dtype, device=device)[:, None] ++ for i in range(4): ++ idx_list[i].extend(indices[i].tolist()) ++ weight_list[i].extend(weights[i].tolist()) ++ ++ idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) ++ weight_tensor = torch.tensor( ++ weight_list, dtype=self.pos_embed.weight.dtype, device=device + ) +- p3 = ( +- self.pos_embed(torch.tensor(idx_list[3], dtype=torch.long, device=device)) +- * torch.tensor(weight_list[3], dtype=dtype, device=device)[:, None] ++ pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] ++ patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] ++ ++ patch_pos_embeds = patch_pos_embeds.split( ++ [h * w for h, w in zip(grid_hs, grid_ws)] + ) + +- patch_pos_embeds = p0 + p1 + p2 + p3 +- patch_pos_embeds = patch_pos_embeds.split([t * h * w for t, h, w in grid_thw]) + patch_pos_embeds_permute = [] +- m_size = self.spatial_merge_size +- for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw): ++ merge_size = self.spatial_merge_size ++ for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): ++ pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( +- pos_embed.view(t, h // m_size, m_size, w // m_size, m_size, -1) ++ pos_embed.view( ++ t, h // merge_size, merge_size, w // merge_size, merge_size, -1 ++ ) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) +@@ -555,21 +536,27 @@ class Qwen3LLMModel(Qwen3Model): + hidden_states + residual if residual is not None else hidden_states + ) + ++ deepstack_embeds = None ++ if input_deepstack_embeds is not None: ++ prev_layer_idx = layer_idx - 1 ++ if prev_layer_idx in self.deepstack_embed_to_decoder_layer: ++ sep = self.hidden_size * prev_layer_idx ++ deepstack_embeds = input_deepstack_embeds[ ++ :, sep : sep + self.hidden_size ++ ] ++ ++ # SGLang applies residual at the START of the next layer, not at the END like HuggingFace. ++ # See: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L549 ++ # To match HF behavior, deepstack must be added AFTER residual: (hidden_states + residual) + deepstack ++ # The order matters because addition with different tensors is not associative in practice. + hidden_states, residual = layer( + positions, + hidden_states, + forward_batch, + residual, ++ post_residual_addition=deepstack_embeds, + ) + +- # process deepstack +- if ( +- input_deepstack_embeds is not None +- and layer_idx in self.deepstack_embed_to_decoder_layer +- ): +- sep = self.hidden_size * layer_idx +- hidden_states += input_deepstack_embeds[:, sep : sep + self.hidden_size] +- + if not self.pp_group.is_last_rank: + return PPProxyTensors( + { +diff --git a/python/sglang/srt/models/step3_vl.py b/python/sglang/srt/models/step3_vl.py +index 4474f62d5..0e537c398 100644 +--- a/python/sglang/srt/models/step3_vl.py ++++ b/python/sglang/srt/models/step3_vl.py +@@ -129,6 +129,7 @@ class Step3TextMoEMLP(nn.Module): + top_k=config.moe_top_k, + renormalize=config.norm_expert_weight, + use_grouped_topk=False, ++ layer_id=layer_id, + ) + + self.experts = get_moe_impl_class(quant_config)( +diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py +index 370aec2b6..47666d8f3 100644 +--- a/python/sglang/srt/multimodal/processors/base_processor.py ++++ b/python/sglang/srt/multimodal/processors/base_processor.py +@@ -13,6 +13,7 @@ from PIL import Image + from transformers import BaseImageProcessorFast + + from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem ++from sglang.srt.server_args import get_global_server_args + from sglang.srt.utils import ( + get_bool_env_var, + is_npu, +@@ -260,7 +261,9 @@ class BaseMultimodalProcessor(ABC): + and isinstance(processor.image_processor, BaseImageProcessorFast) + and not self.server_args.disable_fast_image_processor + ): +- if not _is_npu: ++ if get_global_server_args().rl_on_policy_target is not None: ++ kwargs["device"] = "cpu" ++ elif not _is_npu: + kwargs["device"] = "cuda" + elif processor.__class__.__name__ not in { + "Qwen2_5_VLProcessor", +diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py +index 8e7753dab..323788f39 100644 +--- a/python/sglang/srt/server_args.py ++++ b/python/sglang/srt/server_args.py +@@ -535,6 +535,7 @@ class ServerArgs: + disable_fast_image_processor: bool = False + keep_mm_feature_on_device: bool = False + enable_return_hidden_states: bool = False ++ enable_return_routed_experts: bool = False + scheduler_recv_interval: int = 1 + numa_node: Optional[List[int]] = None + enable_deterministic_inference: bool = False +@@ -1966,6 +1967,9 @@ class ServerArgs: + "Enable deterministic inference because of rl_on_policy_target." + ) + self.enable_deterministic_inference = True ++ ++ # For VLM ++ os.environ["SGLANG_VLM_CACHE_SIZE_MB"] = "0" + # TODO remove this environment variable as a whole + os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = "1" + +@@ -3705,6 +3709,11 @@ class ServerArgs: + action="store_true", + help="Enable returning hidden states with responses.", + ) ++ parser.add_argument( ++ "--enable-return-routed-experts", ++ action="store_true", ++ help="Enable returning routed experts of each layer with responses.", ++ ) + parser.add_argument( + "--scheduler-recv-interval", + type=int, +diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py +index b3d72df05..ddfe0b178 100644 +--- a/python/sglang/srt/speculative/eagle_info.py ++++ b/python/sglang/srt/speculative/eagle_info.py +@@ -746,6 +746,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): + self.topk_index = self.topk_index[: len(new_indices)] + self.hidden_states = self.hidden_states[: len(new_indices)] + self.verified_id = self.verified_id[: len(new_indices)] ++ if self.accept_length is not None: ++ self.accept_length = self.accept_length[: len(new_indices)] ++ if self.accept_length_cpu is not None: ++ self.accept_length_cpu = self.accept_length_cpu[: len(new_indices)] + else: + # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index` + self.topk_p = self.topk_p[new_indices] +@@ -777,6 +781,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): + self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) + self.topk_p = torch.cat([self.topk_p, spec_info.topk_p]) + self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) ++ if self.accept_length is not None and spec_info.accept_length is not None: ++ self.accept_length = torch.cat( ++ [self.accept_length, spec_info.accept_length] ++ ) ++ self.accept_length_cpu = self.accept_length.tolist() ++ elif self.accept_length is not None: ++ zeros = torch.zeros( ++ [spec_info.verified_id.shape[0]], ++ dtype=self.accept_length.dtype, ++ device=self.accept_length.device, ++ ) ++ self.accept_length = torch.cat([self.accept_length, zeros]) ++ self.accept_length_cpu = self.accept_length.tolist() ++ elif spec_info.accept_length is not None: ++ zeros = torch.zeros( ++ [self.verified_id.shape[0]], ++ dtype=self.accept_length.dtype, ++ device=self.accept_length.device, ++ ) ++ self.accept_length = torch.cat([zeros, spec_info.accept_length]) ++ self.accept_length_cpu = self.accept_length.tolist() + + + @dataclass diff --git a/docker/version.txt b/docker/version.txt index 40fee9d81..b480e0254 100644 --- a/docker/version.txt +++ b/docker/version.txt @@ -1 +1 @@ -nightly-dev-20251209d \ No newline at end of file +nightly-dev-20251212a \ No newline at end of file diff --git a/docs/en/examples/deepseek-r1.md b/docs/en/examples/deepseek-r1.md index 28f2f88f1..e1c24e3ad 100644 --- a/docs/en/examples/deepseek-r1.md +++ b/docs/en/examples/deepseek-r1.md @@ -16,7 +16,7 @@ For instructions on setting up the environment and downloading data, please refe To prepare the DeepSeek R1 checkpoint, first you will need to download DeepSeek-R1 to a directory accessible by all machines (hereinafter referred to as `$BASE_DIR`): ```bash -huggingface-cli download deepseek-ai/DeepSeek-R1 --local-dir $BASE_DIR/DeepSeek-R1 +hf download deepseek-ai/DeepSeek-R1 --local-dir $BASE_DIR/DeepSeek-R1 ``` The Hugging Face checkpoint for DeepSeek-R1 is in a block-quantized fp8 format. To convert it into a torch_dist format that Megatron can load, you first need to convert it to a bf16 Hugging Face checkpoint: diff --git a/docs/en/examples/glm4-9B.md b/docs/en/examples/glm4-9B.md index 5c4917e95..f46e9f373 100644 --- a/docs/en/examples/glm4-9B.md +++ b/docs/en/examples/glm4-9B.md @@ -15,14 +15,14 @@ Download the model and data: ```bash # hf checkpoint -huggingface-cli download zai-org/GLM-Z1-9B-0414 --local-dir /root/GLM-Z1-9B-0414 +hf download zai-org/GLM-Z1-9B-0414 --local-dir /root/GLM-Z1-9B-0414 # train data -huggingface-cli download --repo-type dataset zhuzilin/dapo-math-17k \ +hf download --repo-type dataset zhuzilin/dapo-math-17k \ --local-dir /root/dapo-math-17k # eval data -huggingface-cli download --repo-type dataset zhuzilin/aime-2024 \ +hf download --repo-type dataset zhuzilin/aime-2024 \ --local-dir /root/aime-2024 ``` @@ -110,7 +110,7 @@ ROLLOUT_ARGS=( --n-samples-per-prompt 8 # Rollout sampling parameters --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 # Number of training steps corresponding to one rollout --num-steps-per-rollout 1 @@ -129,7 +129,7 @@ EVAL_ARGS=( --eval-prompt-data /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) ``` diff --git a/docs/en/examples/qwen3-30B-A3B.md b/docs/en/examples/qwen3-30B-A3B.md index e51de0d6d..965ef7eb5 100644 --- a/docs/en/examples/qwen3-30B-A3B.md +++ b/docs/en/examples/qwen3-30B-A3B.md @@ -79,7 +79,7 @@ Here, we will briefly introduce the MoE-related parts in the [run-qwen3-30B-A3B. miles also supports BF16 training with FP8 inference. For the Qwen3-30B-A3B model, you just need to download the following model: ```bash -huggingface-cli download Qwen/Qwen3-30B-A3B-FP8 --local-dir /root/Qwen3-30B-A3B-FP8 +hf download Qwen/Qwen3-30B-A3B-FP8 --local-dir /root/Qwen3-30B-A3B-FP8 ``` And replace `--hf-checkpoint` with: diff --git a/docs/en/examples/qwen3-4B.md b/docs/en/examples/qwen3-4B.md index 27d2a4c1c..1966fd823 100644 --- a/docs/en/examples/qwen3-4B.md +++ b/docs/en/examples/qwen3-4B.md @@ -15,14 +15,14 @@ Download the model and data: ```bash # hf checkpoint -huggingface-cli download Qwen/Qwen3-4B --local-dir /root/Qwen3-4B +hf download Qwen/Qwen3-4B --local-dir /root/Qwen3-4B # train data -huggingface-cli download --repo-type dataset zhuzilin/dapo-math-17k \ +hf download --repo-type dataset zhuzilin/dapo-math-17k \ --local-dir /root/dapo-math-17k # eval data -huggingface-cli download --repo-type dataset zhuzilin/aime-2024 \ +hf download --repo-type dataset zhuzilin/aime-2024 \ --local-dir /root/aime-2024 ``` @@ -110,7 +110,7 @@ ROLLOUT_ARGS=( --n-samples-per-prompt 8 # Rollout sampling parameters --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 # Number of training steps corresponding to one rollout --num-steps-per-rollout 1 @@ -129,7 +129,7 @@ EVAL_ARGS=( --eval-prompt-data /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) ``` diff --git a/docs/en/get_started/quick_start.md b/docs/en/get_started/quick_start.md index b562547fb..db07ab705 100644 --- a/docs/en/get_started/quick_start.md +++ b/docs/en/get_started/quick_start.md @@ -39,13 +39,12 @@ docker run --rm --gpus all --ipc=host --shm-size=16g \ ### Install miles -After entering the Docker container, please follow these steps to clone the miles repository and install it: +miles is already installed in the docker image. To update to the latest version, please execute the following command: ```bash # Path can be adjusted according to actual situation -cd /root/ -git clone https://github.com/radixark/miles.git -cd miles +cd /root/miles +git pull pip install -e . ``` @@ -54,8 +53,6 @@ pip install -e . You can download required models and datasets from platforms like Hugging Face, ModelScope, etc. Here are the commands to download example resources using `huggingface_hub`: ```bash -pip install -U huggingface_hub - # Download model weights (GLM-Z1-9B) hf download zai-org/GLM-Z1-9B-0414 --local-dir /root/GLM-Z1-9B-0414 @@ -203,7 +200,7 @@ ROLLOUT_ARGS=( # Rollout sampling parameters --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 # Load balancing for data collected in rollout phase. It ensures that the computational workload allocated to each training process (DP rank) is roughly equal, which may be beneficial for training speed --balance-data @@ -225,7 +222,7 @@ EVAL_ARGS=( # Maximum response length during evaluation --eval-max-response-len 16384 # Sampling parameters during evaluation - --eval-top-p 0.7 + --eval-top-p 1 ) ``` @@ -569,7 +566,17 @@ ray job submit --address="http://127.0.0.1:8265" \ --... # Other Megatron/SGLang/miles arguments ``` +Optionally, the following environment variables may be needed based on your environment. For example, when there are multiple IPs and the wrong one is chosen in a Docker or SLURM envionment. We provide an example used in a SLURM + enroot multi-node system as follows: + +``` +export MILES_HOST_IP=$(hostname -I | awk '{print $1}') +export GLOO_SOCKET_IFNAME=$(ip -o -4 addr show | awk '$4 ~ /^10\\./ {print $2}') +export NCCL_SOCKET_IFNAME=$(ip -o -4 addr show | awk '$4 ~ /^10\\./ {print $2}') +export NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME=$(ip -o -4 addr show | awk '$4 ~ /^10\./ {print $2}') +``` + miles has been deeply optimized for distributed training of large-scale Mixture of Experts (MoE) models. We provide some end-to-end training cases for reference: -- [Example: 64xH100 Training GLM-4.5](models/glm4.5-355B-A32B.md) -- [Example: 128xH100 Training DeepSeek-R1](models/deepseek-r1.md) +- [Example: 64xH100 Training GLM-4.5](../examples/glm4.5-355B-A32B.md) +- [Example: 128xH100 Training DeepSeek-R1](../examples/deepseek-r1.md) +- The scripts such as `scripts/run_qwen3_30b_a3b.py`, `scripts/run_glm45_355b_a32b.py` also support multi-node training, though there are little documentations about it currently. diff --git a/docs/en/platform_support/amd_tutorial.md b/docs/en/platform_support/amd_tutorial.md index 6c2def2d9..790aede5d 100644 --- a/docs/en/platform_support/amd_tutorial.md +++ b/docs/en/platform_support/amd_tutorial.md @@ -151,7 +151,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -162,7 +162,7 @@ EVAL_ARGS=( --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/examples/eval/scripts/run-qwen3-32B.sh b/examples/eval/scripts/run-qwen3-32B.sh index 0a235aeea..eb6702deb 100644 --- a/examples/eval/scripts/run-qwen3-32B.sh +++ b/examples/eval/scripts/run-qwen3-32B.sh @@ -54,7 +54,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data diff --git a/examples/eval/scripts/run-qwen3-4B.sh b/examples/eval/scripts/run-qwen3-4B.sh index 64b8c5cb1..4343377f1 100644 --- a/examples/eval/scripts/run-qwen3-4B.sh +++ b/examples/eval/scripts/run-qwen3-4B.sh @@ -54,7 +54,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data diff --git a/examples/eval_multi_task/multi_task.sh b/examples/eval_multi_task/multi_task.sh index eb028d91e..8236e6d2f 100644 --- a/examples/eval_multi_task/multi_task.sh +++ b/examples/eval_multi_task/multi_task.sh @@ -48,7 +48,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data diff --git a/examples/formal_math/single_round/run.py b/examples/formal_math/single_round/run.py index c688fe341..8cbb9d738 100644 --- a/examples/formal_math/single_round/run.py +++ b/examples/formal_math/single_round/run.py @@ -57,7 +57,7 @@ def execute(): "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 8192 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--global-batch-size 256 " "--balance-data " ) @@ -77,7 +77,7 @@ def execute(): "--eval-interval 20 " "--n-samples-per-eval-prompt 1 " f"--eval-max-response-len {eval_max_response_len or 16384} " - "--eval-top-p 0.7 " + "--eval-top-p 1 " ) if mode == "eval_flc": diff --git a/examples/formal_math/single_round/run_minimal.py b/examples/formal_math/single_round/run_minimal.py index dad439c2b..469d6b833 100644 --- a/examples/formal_math/single_round/run_minimal.py +++ b/examples/formal_math/single_round/run_minimal.py @@ -33,7 +33,7 @@ "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 8192 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--global-batch-size 256 " "--balance-data " "--num-rollout 3000 " @@ -43,7 +43,7 @@ "--eval-interval 20 " "--n-samples-per-eval-prompt 1 " "--eval-max-response-len 16384 " - "--eval-top-p 0.7 " + "--eval-top-p 1 " "--eval-prompt-data " "minif2f /root/datasets/formal_math_single_round/minimal_demo/minif2f_test.jsonl " ) diff --git a/examples/fully_async/run-qwen3-4b-fully_async.sh b/examples/fully_async/run-qwen3-4b-fully_async.sh index 63e7e7818..026e48608 100644 --- a/examples/fully_async/run-qwen3-4b-fully_async.sh +++ b/examples/fully_async/run-qwen3-4b-fully_async.sh @@ -52,7 +52,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data diff --git a/examples/geo3k_vlm/README.md b/examples/geo3k_vlm/README.md index 4cf68cf58..1946999dd 100644 --- a/examples/geo3k_vlm/README.md +++ b/examples/geo3k_vlm/README.md @@ -29,4 +29,7 @@ All three performed similarly, so we use the default math RM for simplicity. Our initial geo3k-specific verifier produced "format scores" (**0 and 0.9**) instead of clean binary rewards. Under **fp32**, fractional values like 0.9 can't be exactly represented, so when all samples in a group have the same reward, `reward - mean` doesn't equal zero—creating spurious gradient signal. -We fixed this by switching to the default math RM with clean **binary 0/1 rewards**. If you encounter similar precision issues with non-binary rewards, you can change the reward tensor dtype from `torch.float` to `torch.float16` in `miles/ray/rollout.py` (`_post_process_rewards` method) to truncate precision artifacts. \ No newline at end of file +We fixed this by switching to the default math RM with clean **binary 0/1 rewards**. If you encounter similar precision issues with non-binary rewards, you can change the reward tensor dtype from `torch.float` to `torch.float16` in `miles/ray/rollout.py` (`_post_process_rewards` method) to truncate precision artifacts. + +## B200 +Blackwell currently does not support fa3, we need to use `--sglang-mm-attention-backend sdpa` and `--attn-implementation flash_attention_2` \ No newline at end of file diff --git a/examples/geo3k_vlm/run_geo3k_vlm.py b/examples/geo3k_vlm/run_geo3k_vlm.py index 6f5a9c59e..0106d2beb 100644 --- a/examples/geo3k_vlm/run_geo3k_vlm.py +++ b/examples/geo3k_vlm/run_geo3k_vlm.py @@ -8,7 +8,6 @@ NUM_GPUS = int(os.environ.get("MILES_SCRIPT_NUM_GPUS", "1")) EXTERNAL_RAY = int(os.environ.get("MILES_SCRIPT_EXTERNAL_RAY", "0")) -MASTER_ADDR = os.environ.get("MASTER_ADDR", "127.0.0.1") def prepare(): @@ -35,12 +34,12 @@ def execute(): "--rollout-batch-size 64 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 4096 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--global-batch-size 512 " ) eval_args = ( - # "--eval-interval 20 " + "--eval-interval 20 " "--eval-prompt-data geo3k /root/datasets/geo3k_imgurl/test.parquet " "--n-samples-per-eval-prompt 1 " "--eval-max-response-len 4096 " @@ -119,27 +118,6 @@ def execute(): # f"{true_on_policy_args} " ) - # Kill existing processes - U.exec_command( - "pkill -9 sglang; " - "sleep 3; " - f"{'' if EXTERNAL_RAY else 'ray stop --force; '}" - f"{'' if EXTERNAL_RAY else 'pkill -9 ray; '}" - "pkill -9 miles; " - "sleep 3; " - f"{'' if EXTERNAL_RAY else 'pkill -9 ray; '}" - "pkill -9 miles; " - "pkill -9 redis; " - "true; " - ) - - if not EXTERNAL_RAY: - # Start Ray - U.exec_command( - f"export PYTHONBUFFERED=16 && " - f"ray start --head --node-ip-address {MASTER_ADDR} --num-gpus {NUM_GPUS} " - f"--disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265" - ) # Submit Ray job execute_train( train_args=train_args, diff --git a/examples/low_precision/run-qwen3-30b-a3b-fp8-two-nodes.sh b/examples/low_precision/run-qwen3-30b-a3b-fp8-two-nodes.sh index 291fd8cb9..e761a3acd 100644 --- a/examples/low_precision/run-qwen3-30b-a3b-fp8-two-nodes.sh +++ b/examples/low_precision/run-qwen3-30b-a3b-fp8-two-nodes.sh @@ -49,7 +49,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 16 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 128 --balance-data @@ -60,7 +60,7 @@ EVAL_ARGS=( --eval-prompt-data aime "${BASE_DIR}/aime-2024.jsonl" --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/examples/low_precision/run-qwen3-4b-fp8.sh b/examples/low_precision/run-qwen3-4b-fp8.sh index 22e7d2e1f..b196ba606 100644 --- a/examples/low_precision/run-qwen3-4b-fp8.sh +++ b/examples/low_precision/run-qwen3-4b-fp8.sh @@ -46,7 +46,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -57,7 +57,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/data/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/examples/multi_agent/README.md b/examples/multi_agent/README.md index 0b6bf49fa..c974640bc 100644 --- a/examples/multi_agent/README.md +++ b/examples/multi_agent/README.md @@ -45,7 +45,7 @@ ROLLOUT_ARGS=( --n-samples-per-prompt 8 --rollout-max-context-len 16384 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data diff --git a/examples/multi_agent/run-qwen3-30B-A3B-multi-agent.sh b/examples/multi_agent/run-qwen3-30B-A3B-multi-agent.sh index f8e3952da..f3e5f1466 100644 --- a/examples/multi_agent/run-qwen3-30B-A3B-multi-agent.sh +++ b/examples/multi_agent/run-qwen3-30B-A3B-multi-agent.sh @@ -48,7 +48,7 @@ ROLLOUT_ARGS=( --n-samples-per-prompt 8 --rollout-max-context-len 16384 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -60,7 +60,7 @@ EVAL_ARGS=( # --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/examples/on_policy_distillation/run-qwen3-8B-opd.sh b/examples/on_policy_distillation/run-qwen3-8B-opd.sh index bb63a8513..c57b9eef4 100644 --- a/examples/on_policy_distillation/run-qwen3-8B-opd.sh +++ b/examples/on_policy_distillation/run-qwen3-8B-opd.sh @@ -63,7 +63,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 16 --n-samples-per-prompt 4 --rollout-max-response-len 16384 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 64 --balance-data @@ -80,7 +80,7 @@ EVAL_ARGS=( # --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl # --n-samples-per-eval-prompt 16 # --eval-max-response-len 16384 - # --eval-top-p 0.7 + # --eval-top-p 1 ) PERF_ARGS=( diff --git a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k.sh b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k.sh index 6cf2af46c..9fab2cd2f 100644 --- a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k.sh +++ b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k.sh @@ -34,7 +34,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 1024 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 ) diff --git a/examples/retool/retool_qwen3_4b_rl.sh b/examples/retool/retool_qwen3_4b_rl.sh index e3e8f0b2e..838ce0e2c 100644 --- a/examples/retool/retool_qwen3_4b_rl.sh +++ b/examples/retool/retool_qwen3_4b_rl.sh @@ -46,7 +46,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -57,7 +57,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/examples/strands-agents/strands_qwen3_4b.sh b/examples/strands-agents/strands_qwen3_4b.sh index 2df6ff0dd..647c8e2f5 100644 --- a/examples/strands-agents/strands_qwen3_4b.sh +++ b/examples/strands-agents/strands_qwen3_4b.sh @@ -47,7 +47,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -58,7 +58,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/data/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/examples/tau-bench/run_qwen3_4B.sh b/examples/tau-bench/run_qwen3_4B.sh index 25eb22545..a13067d0c 100644 --- a/examples/tau-bench/run_qwen3_4B.sh +++ b/examples/tau-bench/run_qwen3_4B.sh @@ -42,7 +42,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 1024 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std --balance-data diff --git a/examples/train_infer_mismatch_helper/run-qwen3-4b-mis.sh b/examples/train_infer_mismatch_helper/run-qwen3-4b-mis.sh index d8cac2e87..300e8ac75 100644 --- a/examples/train_infer_mismatch_helper/run-qwen3-4b-mis.sh +++ b/examples/train_infer_mismatch_helper/run-qwen3-4b-mis.sh @@ -46,7 +46,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -57,7 +57,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 1 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/examples/true_on_policy/run_simple.py b/examples/true_on_policy/run_simple.py index 6112877f9..1b472b806 100644 --- a/examples/true_on_policy/run_simple.py +++ b/examples/true_on_policy/run_simple.py @@ -32,7 +32,7 @@ def execute(): f"--rollout-batch-size {1 if MODE == 'debug_one_sample' else 32} " f"--n-samples-per-prompt {1 if MODE == 'debug_one_sample' else 8} " f"--rollout-max-response-len {2 if MODE == 'debug_one_sample' else 1024} " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " # temp remove this to make test easier # "--over-sampling-batch-size 64 " # "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " diff --git a/examples/true_on_policy_vlm/README.md b/examples/true_on_policy_vlm/README.md index 246665311..786ec2c98 100644 --- a/examples/true_on_policy_vlm/README.md +++ b/examples/true_on_policy_vlm/README.md @@ -5,6 +5,7 @@ This example demonstrates true on-policy training with Qwen3-VL dense model on F

Training Inference Log Prob Diff

+ ## Usage ```bash diff --git a/examples/true_on_policy_vlm/run_simple.py b/examples/true_on_policy_vlm/run_simple.py index 89b8ec229..3f6e17541 100644 --- a/examples/true_on_policy_vlm/run_simple.py +++ b/examples/true_on_policy_vlm/run_simple.py @@ -8,7 +8,6 @@ NUM_GPUS = int(os.environ.get("MILES_SCRIPT_NUM_GPUS", "1")) EXTERNAL_RAY = int(os.environ.get("MILES_SCRIPT_EXTERNAL_RAY", "0")) -MASTER_ADDR = os.environ.get("MASTER_ADDR", "127.0.0.1") def prepare(): @@ -34,12 +33,12 @@ def execute(): "--rollout-batch-size 64 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 4096 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--global-batch-size 512 " ) eval_args = ( - # "--eval-interval 20 " + "--eval-interval 20 " "--eval-prompt-data geo3k /root/datasets/geo3k_imgurl/test.parquet " "--n-samples-per-eval-prompt 1 " "--eval-max-response-len 4096 " @@ -127,28 +126,6 @@ def execute(): f"{true_on_policy_args} " ) - # Kill existing processes - U.exec_command( - "pkill -9 sglang; " - "sleep 3; " - f"{'' if EXTERNAL_RAY else 'ray stop --force; '}" - f"{'' if EXTERNAL_RAY else 'pkill -9 ray; '}" - "pkill -9 miles; " - "sleep 3; " - f"{'' if EXTERNAL_RAY else 'pkill -9 ray; '}" - "pkill -9 miles; " - "pkill -9 redis; " - "true; " - ) - - if not EXTERNAL_RAY: - # Start Ray - U.exec_command( - f"export PYTHONBUFFERED=16 && " - f"ray start --head --node-ip-address {MASTER_ADDR} --num-gpus {NUM_GPUS} " - f"--disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265" - ) - # Submit Ray job execute_train( train_args=train_args, diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py index 4ff6988f1..1e3e5b3ae 100644 --- a/miles/backends/fsdp_utils/actor.py +++ b/miles/backends/fsdp_utils/actor.py @@ -80,7 +80,8 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty if i == dist.get_rank(): self.hf_config = AutoConfig.from_pretrained(self.args.hf_checkpoint, trust_remote_code=True) self.tokenizer = load_tokenizer(self.args.hf_checkpoint, trust_remote_code=True) - if self.args.multimodal_keys: + # Vision models have `vision_config` in the config + if hasattr(self.hf_config, "vision_config"): self.processor = load_processor(self.args.hf_checkpoint, trust_remote_code=True) dist.barrier(group=get_gloo_group()) diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index 3b60157e1..bcbdd1a42 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -71,6 +71,12 @@ def init( self.train_parallel_config = { "dp_size": mpu.get_data_parallel_world_size(with_context_parallel=False), } + dist.barrier(group=get_gloo_group()) + + if args.offload_train: + if (x := args.train_memory_margin_bytes) > 0: + logger.info(f"Set torch_memory_saver.memory_margin_bytes to {x}") + torch_memory_saver.memory_margin_bytes = x if self.args.debug_rollout_only: return 0 @@ -186,18 +192,37 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch: rollout_data["loss_masks"] = [ torch.tensor(t, dtype=torch.int, device=torch.cuda.current_device()) for t in rollout_data["loss_masks"] ] + + if self.args.qkv_format == "bshd": + # TODO: micro-batch wise dynamic, possibly move to @data.py:get_data_iterator + max_seq_len = max(rollout_data["total_lengths"]) + + # pad to reduce memory fragmentation and maybe make the computation faster + pad_size = mpu.get_tensor_model_parallel_world_size() * self.args.data_pad_size_multiplier + max_seq_len = (max_seq_len + pad_size - 1) // pad_size * pad_size + + rollout_data["max_seq_lens"] = [max_seq_len] * len(rollout_data["tokens"]) + if "rollout_log_probs" in rollout_data: rollout_data["rollout_log_probs"] = [ torch.tensor( - slice_log_prob_with_cp(log_prob, total_length, response_length), + slice_log_prob_with_cp( + log_prob, + total_length, + response_length, + self.args.qkv_format, + rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None, + ), device=torch.cuda.current_device(), dtype=torch.float32, ) - for log_prob, total_length, response_length in zip( - rollout_data["rollout_log_probs"], - rollout_data["total_lengths"], - rollout_data["response_lengths"], - strict=False, + for i, (log_prob, total_length, response_length) in enumerate( + zip( + rollout_data["rollout_log_probs"], + rollout_data["total_lengths"], + rollout_data["response_lengths"], + strict=False, + ) ) ] if "rollout_routed_experts" in rollout_data: diff --git a/miles/backends/megatron_utils/cp_utils.py b/miles/backends/megatron_utils/cp_utils.py index 92baa6954..2e795d3d3 100644 --- a/miles/backends/megatron_utils/cp_utils.py +++ b/miles/backends/megatron_utils/cp_utils.py @@ -9,6 +9,8 @@ def get_logits_and_tokens_offset_with_cp( total_length: int, response_length: int, + qkv_format: str = "thd", + max_seq_len: int | None = None, ): """ All offsets start from the begining of the prompt. @@ -18,7 +20,11 @@ def get_logits_and_tokens_offset_with_cp( assert cp_size > 1 prompt_length = total_length - response_length - chunk_size = (total_length + 2 * cp_size - 1) // (2 * cp_size) + if qkv_format == "thd": + chunk_size = (total_length + 2 * cp_size - 1) // (2 * cp_size) + else: + assert max_seq_len is not None, "max_seq_len must be provided for qkv_format=bshd" + chunk_size = (max_seq_len + 2 * cp_size - 1) // (2 * cp_size) # the offset of 2 chunks chunk_0 = (cp_rank * chunk_size, (cp_rank + 1) * chunk_size) @@ -49,6 +55,8 @@ def get_sum_of_sample_mean( response_lengths: list[int], loss_masks: list[torch.Tensor], calculate_per_token_loss: bool = False, + qkv_format: str = "thd", + max_seq_lens: list[int] | None = None, ) -> Callable[[torch.Tensor], torch.Tensor]: """ Calculate correct sample mean for CP @@ -78,8 +86,11 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor: for i, (total_length, response_length, loss_mask) in enumerate( zip(total_lengths, response_lengths, loss_masks, strict=False) ): + max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None prompt_length = total_length - response_length - _, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp(total_length, response_length) + _, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp( + total_length, response_length, qkv_format, max_seq_len + ) loss_mask_0 = loss_mask[tokens_offset[0][0] - prompt_length : tokens_offset[0][1] - prompt_length] loss_mask_1 = loss_mask[tokens_offset[1][0] - prompt_length : tokens_offset[1][1] - prompt_length] chunked_loss_masks.append(torch.cat([loss_mask_0, loss_mask_1], dim=0)) @@ -160,23 +171,44 @@ def zero(len: int) -> torch.Tensor: return full_tensor -def slice_with_cp(tokens: torch.Tensor, pad_value: tuple[int, float, Callable]) -> torch.Tensor: +def slice_with_cp( + tokens: torch.Tensor, + pad_value: tuple[int, float, Callable], + qkv_format: str = "thd", + max_seq_len: int | None = None, +) -> torch.Tensor: cp_rank = mpu.get_context_parallel_rank() cp_size = mpu.get_context_parallel_world_size() + if qkv_format == "bshd": + assert max_seq_len is not None + + def pad_tokens(tokens, pad): + if isinstance(pad_value, Callable): + pad_func = pad_value + tokens = pad_func(tokens, pad) + else: + # pad on the first dimension + pad_tuple = (0, 0) * (tokens.dim() - 1) + (0, pad) + tokens = F.pad(tokens, pad_tuple, value=pad_value) + return tokens + if cp_size == 1: + if qkv_format == "bshd": + pad = max_seq_len - tokens.size(0) + tokens = pad_tokens(tokens, pad) return tokens - # pad - chunk_size = (len(tokens) + 2 * cp_size - 1) // (2 * cp_size) - pad = 2 * cp_size * chunk_size - len(tokens) - if isinstance(pad_value, Callable): - pad_func = pad_value - tokens = pad_func(tokens, pad) + token_len = len(tokens) + if qkv_format == "thd": + chunk_size = (token_len + 2 * cp_size - 1) // (2 * cp_size) else: - # pad on the first dimension - pad_tuple = (0, 0) * (tokens.dim() - 1) + (0, pad) - tokens = F.pad(tokens, pad_tuple, value=pad_value) + chunk_size = (max_seq_len + 2 * cp_size - 1) // (2 * cp_size) + + # pad + pad = 2 * cp_size * chunk_size - token_len + tokens = pad_tokens(tokens, pad) + # get 2 chunk for thd cp start_1, end_1 = chunk_size * cp_rank, chunk_size * (cp_rank + 1) start_2, end_2 = chunk_size * (2 * cp_size - cp_rank - 1), chunk_size * (2 * cp_size - cp_rank) @@ -187,6 +219,8 @@ def slice_log_prob_with_cp( log_prob: list[float] | torch.Tensor, total_length: int, response_length: int, + qkv_format: str = "thd", + max_token_len: int | None = None, ) -> list[float] | torch.Tensor: assert len(log_prob) == response_length @@ -196,7 +230,9 @@ def slice_log_prob_with_cp( return log_prob prompt_length = total_length - response_length - _, _, logits_offset, _ = get_logits_and_tokens_offset_with_cp(total_length, response_length) + _, _, logits_offset, _ = get_logits_and_tokens_offset_with_cp( + total_length, response_length, qkv_format, max_token_len + ) chunk_1 = log_prob[logits_offset[0][0] - (prompt_length - 1) : logits_offset[0][1] - (prompt_length - 1)] chunk_2 = log_prob[logits_offset[1][0] - (prompt_length - 1) : logits_offset[1][1] - (prompt_length - 1)] diff --git a/miles/backends/megatron_utils/data.py b/miles/backends/megatron_utils/data.py index 73300bed6..f94d1b7e0 100644 --- a/miles/backends/megatron_utils/data.py +++ b/miles/backends/megatron_utils/data.py @@ -23,9 +23,7 @@ def get_batch( - data_iterator: "DataIterator", - keys: Sequence[str], - pad_multiplier: int = 128, + data_iterator: "DataIterator", keys: Sequence[str], pad_multiplier: int = 128, qkv_format: str = "thd" ) -> dict[str, torch.Tensor | PackedSeqParams | list[torch.Tensor] | None]: """ Generate a CP-ready micro-batch with packed sequence parameters. @@ -55,39 +53,50 @@ def get_batch( tokens = batch["tokens"] # use 0 as the pad token id should be fine? pad_token_id = 0 + pad_size = mpu.get_tensor_model_parallel_world_size() * pad_multiplier # for cp, we need all tokens to calculate logprob batch["unconcat_tokens"] = tokens cp_size = mpu.get_context_parallel_world_size() - tokens = [slice_with_cp(t, pad_token_id) for t in tokens] - - cu_seqlens = [0] - for t in tokens: - cu_seqlens.append(cu_seqlens[-1] + t.size(0)) - tokens = torch.cat(tokens) + if qkv_format == "bshd": + max_seqlen = batch["max_seq_lens"][0] + assert max([t.size(0) for t in tokens]) <= max_seqlen + tokens = [slice_with_cp(t, pad_token_id, qkv_format, max_seqlen) for t in tokens] + tokens = torch.stack(tokens) + + elif qkv_format == "thd": + tokens = [slice_with_cp(t, pad_token_id, qkv_format) for t in tokens] + + cu_seqlens = [0] + for t in tokens: + cu_seqlens.append(cu_seqlens[-1] + t.size(0)) + + tokens = torch.cat(tokens) + + # Always pad to reduce memory fragmentation and maybe make the computation faster + pad = (pad_size - tokens.size(0) % pad_size) % pad_size + if pad != 0: + tokens = F.pad(tokens, (0, pad), value=pad_token_id) + cu_seqlens.append(cu_seqlens[-1] + pad) + + # thd requires the cu_seqlens to be of the origin length + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int).cuda() * cp_size + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + + packed_seq_params = PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + qkv_format="thd", + ) - # Always pad to reduce memory fragmentation and maybe make the computation faster - pad_size = mpu.get_tensor_model_parallel_world_size() * pad_multiplier - pad = (pad_size - tokens.size(0) % pad_size) % pad_size - if pad != 0: - tokens = F.pad(tokens, (0, pad), value=pad_token_id) - cu_seqlens.append(cu_seqlens[-1] + pad) - - # thd requires the cu_seqlens to be of the origin length - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int).cuda() * cp_size - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - - packed_seq_params = PackedSeqParams( - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_kv=max_seqlen, - qkv_format="thd", - ) + tokens = tokens.unsqueeze(0) + else: + raise ValueError(f"Unsupported qkv_format: {qkv_format}") - tokens = tokens.unsqueeze(0) batch["tokens"] = tokens batch["packed_seq_params"] = packed_seq_params return batch @@ -311,6 +320,7 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc response_lengths = rollout_data["response_lengths"] loss_masks = rollout_data["loss_masks"] total_lengths = rollout_data["total_lengths"] + max_seq_lens = rollout_data.get("max_seq_lens", None) for key, val in rollout_data.items(): if key in [ @@ -318,6 +328,7 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc "loss_masks", "sample_indices", "rollout_routed_experts", + "max_seq_lens", ]: continue # Upload per sample mean for each rollout value @@ -329,7 +340,13 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc # modified in place and will cause problem for the next rollout. val = torch.cat(val).clone().detach() if key in ["log_probs", "ref_log_probs", "rollout_log_probs", "returns", "advantages", "values"]: - sum_of_sample_mean = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks) + sum_of_sample_mean = get_sum_of_sample_mean( + total_lengths, + response_lengths, + loss_masks, + qkv_format=args.qkv_format, + max_seq_lens=max_seq_lens, + ) val = cp_size * sum_of_sample_mean(val) / len(loss_masks) else: val = val.mean() * cp_size diff --git a/miles/backends/megatron_utils/initialize.py b/miles/backends/megatron_utils/initialize.py index 7c4357436..e9f062c11 100644 --- a/miles/backends/megatron_utils/initialize.py +++ b/miles/backends/megatron_utils/initialize.py @@ -7,7 +7,6 @@ from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator from megatron.training.global_vars import _build_tokenizer, set_args - logger = logging.getLogger(__name__) diff --git a/miles/backends/megatron_utils/loss.py b/miles/backends/megatron_utils/loss.py index be61684e5..d7b72a512 100644 --- a/miles/backends/megatron_utils/loss.py +++ b/miles/backends/megatron_utils/loss.py @@ -4,6 +4,7 @@ import torch from megatron.core import mpu +from torch.utils.checkpoint import checkpoint from miles.utils.distributed_utils import distributed_masked_whiten from miles.utils.misc import load_function @@ -30,6 +31,7 @@ def get_responses( unconcat_tokens: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int], + max_seq_lens: list[int] | None = None, ) -> Iterator[tuple[torch.Tensor, torch.Tensor]]: """Yield response-aligned `(logits_chunk, tokens_chunk)` pairs per sample. @@ -52,24 +54,40 @@ def get_responses( `[R, V]` (policy) or `[R, 1]` (value) and `tokens_chunk` is shape `[R]` (1D int64), both aligned to response tokens for one sample. """ - assert logits.size(0) == 1, f"{logits.shape}" + qkv_format = args.qkv_format + assert logits.dtype == torch.float32, f"{logits.dtype}" + assert len(logits.shape) == 3, f"{logits.shape}" + + if qkv_format == "thd": + assert logits.size(0) == 1, f"{logits.shape}" + logits = logits.squeeze(0) + else: + assert max_seq_lens is not None + logits = logits.view(-1, logits.size(-1)) - logits = logits.squeeze(0) logits = logits.div(args.rollout_temperature) cp_size = mpu.get_context_parallel_world_size() end = 0 - for tokens, total_length, response_length in zip(unconcat_tokens, total_lengths, response_lengths, strict=False): + for i, (tokens, total_length, response_length) in enumerate( + zip(unconcat_tokens, total_lengths, response_lengths, strict=False) + ): + max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None + if cp_size == 1: - end += total_length - start = end - response_length + if qkv_format == "bshd": + end = max_seq_len * i + total_length + start = end - response_length + else: + end += total_length + start = end - response_length logits_chunk = logits[start - 1 : end - 1] tokens_chunk = tokens[-response_length:] else: # TODO: this is super ugly... do better abstraction. chunk_size, chunks_offset, logits_offset, tokens_offset = get_logits_and_tokens_offset_with_cp( - total_length, response_length + total_length, response_length, qkv_format, max_seq_len ) logits_0, logits_1 = logits[end : end + chunk_size], logits[end + chunk_size : end + 2 * chunk_size] @@ -99,6 +117,7 @@ def get_log_probs_and_entropy( response_lengths: list[int], with_entropy: bool = False, non_loss_data: bool = True, + max_seq_lens: list[int] | None = None, ) -> dict[str, list[torch.Tensor]]: """Compute per-token log-probabilities (and optionally entropy) on responses. @@ -131,6 +150,7 @@ def get_log_probs_and_entropy( unconcat_tokens=unconcat_tokens, total_lengths=total_lengths, response_lengths=response_lengths, + max_seq_lens=max_seq_lens, ): log_prob, entropy = calculate_log_probs_and_entropy( logits_chunk, tokens_chunk, mpu.get_tensor_model_parallel_group(), with_entropy=with_entropy @@ -156,6 +176,7 @@ def get_values( response_lengths: list[int], with_entropy: bool = False, non_loss_data: bool = True, + max_seq_lens: list[int] | None = None, ) -> dict[str, list[torch.Tensor]]: """Extract per-token value predictions over response tokens. @@ -183,6 +204,7 @@ def get_values( unconcat_tokens=unconcat_tokens, total_lengths=total_lengths, response_lengths=response_lengths, + max_seq_lens=max_seq_lens, ): assert logits_chunk.size(-1) == 1, f"{logits_chunk.shape}" value_list.append(logits_chunk.squeeze(-1)) @@ -220,6 +242,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) response_lengths: list[int] = rollout_data.get("response_lengths") loss_masks: list[torch.Tensor] = rollout_data.get("loss_masks") total_lengths: list[int] = rollout_data.get("total_lengths") + max_seq_lens: list[int] | None = rollout_data.get("max_seq_lens", None) # return when not the last pp stage. if log_probs is None and values is None: @@ -313,8 +336,11 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) total_len = total_lengths[i] response_len = response_lengths[i] prompt_len = total_len - response_len + max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None - _, _, _, token_offsets = get_logits_and_tokens_offset_with_cp(total_len, response_len) + _, _, _, token_offsets = get_logits_and_tokens_offset_with_cp( + total_len, response_len, args.qkv_format, max_seq_len + ) # Convert global offsets to response-space offsets s0, e0 = token_offsets[0] @@ -443,6 +469,7 @@ def policy_loss_function( response_lengths = batch["response_lengths"] total_lengths = batch["total_lengths"] + max_seq_lens = batch.get("max_seq_lens", None) log_probs_and_entropy = get_log_probs_and_entropy( logits, @@ -451,6 +478,7 @@ def policy_loss_function( total_lengths=total_lengths, response_lengths=response_lengths, with_entropy=True, + max_seq_lens=max_seq_lens, ) log_probs = log_probs_and_entropy["log_probs"] @@ -528,7 +556,12 @@ def policy_loss_function( # [decouple IS and rejection] Rebuild sum_of_sample_mean with modified_response_masks for denominator correction # modified_response_masks will be sliced with cp in get_sum_of_sample_mean sum_of_sample_mean = get_sum_of_sample_mean( - total_lengths, response_lengths, modified_response_masks, args.calculate_per_token_loss + total_lengths, + response_lengths, + modified_response_masks, + args.calculate_per_token_loss, + args.qkv_format, + batch.get("max_seq_lens", None), ) pg_loss = sum_of_sample_mean(pg_loss) @@ -625,6 +658,7 @@ def value_loss_function( unconcat_tokens=batch["unconcat_tokens"], total_lengths=batch["total_lengths"], response_lengths=batch["response_lengths"], + max_seq_lens=batch.get("max_seq_lens", None), ) values = torch.cat([value.flatten() for value in values["values"]], dim=0) @@ -683,6 +717,7 @@ def sft_loss_function( total_lengths=total_lengths, response_lengths=response_lengths, with_entropy=False, + max_seq_lens=batch.get("max_seq_lens", None), ) log_probs = log_probs_and_entropy["log_probs"] @@ -738,28 +773,27 @@ def loss_function( batch["response_lengths"], batch["loss_masks"], args.calculate_per_token_loss, + args.qkv_format, + batch.get("max_seq_lens", None), ) - loss_function_kwargs = { - "args": args, - "batch": batch, - "logits": logits, - "sum_of_sample_mean": sum_of_sample_mean, - } - match args.loss_type: case "policy_loss": - loss, log = policy_loss_function(**loss_function_kwargs) + func = policy_loss_function case "value_loss": - loss, log = value_loss_function(**loss_function_kwargs) + func = value_loss_function case "sft_loss": - loss, log = sft_loss_function(**loss_function_kwargs) + func = sft_loss_function case "custom_loss": - custom_loss_function = load_function(args.custom_loss_function_path) - loss, log = custom_loss_function(**loss_function_kwargs) + func = load_function(args.custom_loss_function_path) case _: raise ValueError(f"Unknown loss type: {args.loss_type}") + if args.recompute_loss_function: + loss, log = checkpoint(func, args, batch, logits, sum_of_sample_mean) + else: + loss, log = func(args, batch, logits, sum_of_sample_mean) + # Here we need to divide by cp_size because to cancel the multiply in Megatron. if not args.calculate_per_token_loss: loss = ( diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index 7e8c8770d..c4e182797 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -210,7 +210,10 @@ def forward_step( # Get the batch. batch = get_batch( - data_iterator, ["tokens", "total_lengths", "response_lengths"], args.data_pad_size_multiplier + data_iterator, + ["tokens", "total_lengths", "response_lengths", "max_seq_lens"], + args.data_pad_size_multiplier, + args.qkv_format, ) unconcat_tokens = batch["unconcat_tokens"] tokens = batch["tokens"] @@ -232,6 +235,7 @@ def forward_step( total_lengths=total_lengths, response_lengths=response_lengths, with_entropy=args.use_rollout_entropy, + max_seq_lens=batch.get("max_seq_lens", None), ) # Turn on evaluation mode which disables dropout. @@ -361,8 +365,10 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p "advantages", "returns", "rollout_log_probs", + "max_seq_lens", ], args.data_pad_size_multiplier, + args.qkv_format, ) if os.environ.get("ENABLE_ROUTING_REPLAY", "0") == "1": diff --git a/miles/backends/megatron_utils/model_provider.py b/miles/backends/megatron_utils/model_provider.py index 04b3db1b6..5b7b3dd74 100644 --- a/miles/backends/megatron_utils/model_provider.py +++ b/miles/backends/megatron_utils/model_provider.py @@ -89,19 +89,19 @@ def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage # Define the decoder layer spec if use_te: transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - args.num_experts, - args.moe_grouped_gemm, - args.qk_layernorm, - args.multi_latent_attention, - args.moe_use_legacy_grouped_gemm, + num_experts=args.num_experts, + moe_grouped_gemm=args.moe_grouped_gemm, + qk_layernorm=args.qk_layernorm, + multi_latent_attention=args.multi_latent_attention, + moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm, ) else: transformer_layer_spec = get_gpt_layer_local_spec( - args.num_experts, - args.moe_grouped_gemm, - args.qk_layernorm, - args.multi_latent_attention, - args.moe_use_legacy_grouped_gemm, + num_experts=args.num_experts, + moe_grouped_gemm=args.moe_grouped_gemm, + qk_layernorm=args.qk_layernorm, + multi_latent_attention=args.multi_latent_attention, + moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm, ) build_model_context = nullcontext diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index b300ad167..9ee0fbb8a 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -305,7 +305,6 @@ def _split_train_data_by_dp(self, data, dp_size): "multimodal_inputs", "response_lengths", "rewards", - "raw_reward", "truncated", "loss_masks", "round_number", @@ -321,6 +320,7 @@ def _split_train_data_by_dp(self, data, dp_size): rollout_data[key] = val # keys that need to be splited at train side for key in [ + "raw_reward", "total_lengths", ]: if key not in data: diff --git a/miles/ray/train_actor.py b/miles/ray/train_actor.py index 4244c022a..3d3923e9c 100644 --- a/miles/ray/train_actor.py +++ b/miles/ray/train_actor.py @@ -7,7 +7,6 @@ import ray import torch import torch.distributed as dist -from torch_memory_saver import torch_memory_saver import miles.utils.eval_config from miles.ray.ray_actor import RayActor @@ -53,11 +52,6 @@ def init(self, args, role, with_ref=False): self.role = role self.with_ref = with_ref - if (x := args.train_memory_margin_bytes) > 0: - logger.info(f"Set torch_memory_saver.memory_margin_bytes to {x}") - assert args.offload_train - torch_memory_saver.memory_margin_bytes = x - torch.serialization.add_safe_globals([miles.utils.eval_config.EvalDatasetConfig]) local_rank = int(os.environ.get("LOCAL_RANK", 0)) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 54ff574c9..ce6e47161 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -118,6 +118,13 @@ def add_train_arguments(parser): default="megatron", help="The backend for training.", ) + parser.add_argument( + "--qkv-format", + type=str, + choices=["thd", "bshd"], + default="thd", + help="The qkv layout for Megatron backend.", + ) parser.add_argument( "--true-on-policy-mode", action="store_true", @@ -133,8 +140,8 @@ def add_train_arguments(parser): parser.add_argument( "--train-memory-margin-bytes", type=int, - default=0, - help="Add margin for train memory allocation.", + default=1024**3, + help="Add margin for train memory allocation. By default we will reserve 1GB as margin.", ) parser.add_argument( "--disable-weights-backuper", @@ -148,6 +155,11 @@ def add_train_arguments(parser): default="raw", help="The method to convert megatron weights to hugging face weights for SGLang.", ) + parser.add_argument( + "--recompute-loss-function", + action="store_true", + help="Whether to disable recompute loss function to save memory during training.", + ) return parser @@ -1580,6 +1592,16 @@ def miles_validate_args(args): if args.prefill_num_servers is not None: assert not args.use_fault_tolerance, "fault tolerance is not supported when prefill_num_servers is set." + assert args.qkv_format in [ + "thd", + "bshd", + ], f"qkv_format {args.qkv_format} is not supported. (only 'thd' and 'bshd' are supported)" + if args.qkv_format == "bshd": + assert args.train_backend == "megatron", "bshd format is only supported for megatron backend." + assert ( + args.use_dynamic_batch_size is False + ), "Dynamic batch size is not supported for bshd format. Please specify --micro-batch-size instead." + def hf_validate_args(args, hf_config): def equal(x, y): diff --git a/miles/utils/debug_utils/send_to_sglang.py b/miles/utils/debug_utils/send_to_sglang.py index c183edf79..0c2d4b57d 100644 --- a/miles/utils/debug_utils/send_to_sglang.py +++ b/miles/utils/debug_utils/send_to_sglang.py @@ -22,7 +22,7 @@ def main( Minimally send prompts to SGLang using OpenAI endpoints with arguments in the same format as main Miles. Example usage: - python -m miles.utils.debug_utils.send_to_sglang --prompt-data /root/datasets/aime-2024/aime-2024.jsonl --input-key prompt --n-samples-per-prompt 16 --rollout-max-response-len 32768 --rollout-temperature 0.8 --rollout-top-p 0.7 + python -m miles.utils.debug_utils.send_to_sglang --prompt-data /root/datasets/aime-2024/aime-2024.jsonl --input-key prompt --n-samples-per-prompt 16 --rollout-max-response-len 32768 --rollout-temperature 1 --rollout-top-p 1 """ async def _main_async(): diff --git a/miles/utils/flops_utils.py b/miles/utils/flops_utils.py index 57cf9de90..71cdd4c65 100644 --- a/miles/utils/flops_utils.py +++ b/miles/utils/flops_utils.py @@ -16,8 +16,8 @@ def calculate_qkv_projection_flops(seqlen, hidden_size, num_attention_heads, num def calculate_attention_flops(seqlen, num_attention_heads, head_dim): - # QK^T - flops = 2 * num_attention_heads * seqlen * seqlen * head_dim + # QK^T with causal + flops = 2 * num_attention_heads * seqlen * seqlen * head_dim // 2 # A*V flops += 2 * num_attention_heads * seqlen * seqlen * head_dim return flops @@ -31,8 +31,9 @@ def calculate_mlp_flops(seqlen, hidden_size, ffn_hidden_size): return 2 * seqlen * hidden_size * ffn_hidden_size * 3 -def calculate_layer_flops(seqlen, hidden_size, num_attention_heads, num_query_groups, ffn_hidden_size): - head_dim = hidden_size // num_attention_heads +def calculate_layer_flops(seqlen, hidden_size, num_attention_heads, num_query_groups, ffn_hidden_size, head_dim): + if head_dim is None: + head_dim = hidden_size // num_attention_heads return ( calculate_qkv_projection_flops(seqlen, hidden_size, num_attention_heads, num_query_groups) + calculate_attention_flops(seqlen, num_attention_heads, head_dim) @@ -49,6 +50,7 @@ def calculate_fwd_flops( num_attention_heads = args.num_attention_heads num_query_groups = args.num_query_groups vocab_size = args.vocab_size + kv_channels = args.kv_channels total_flops = 0 @@ -82,6 +84,7 @@ def calculate_fwd_flops( num_attention_heads, num_query_groups, dense_ffn, + kv_channels, ) * num_dense_layers ) @@ -94,6 +97,7 @@ def calculate_fwd_flops( num_attention_heads, num_query_groups, moe_ffn, + kv_channels, ) * num_moe_layers ) diff --git a/scripts/run-deepseek-r1.sh b/scripts/run-deepseek-r1.sh index f8a3779b4..93e6c0f4b 100644 --- a/scripts/run-deepseek-r1.sh +++ b/scripts/run-deepseek-r1.sh @@ -46,7 +46,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 128 --n-samples-per-prompt 8 --rollout-max-response-len 32768 - --rollout-temperature 0.8 + --rollout-temperature 1 --over-sampling-batch-size 256 --dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std @@ -60,7 +60,7 @@ EVAL_ARGS=( --eval-prompt-data aime $BASE_DIR/rl_data/aime-2024.jsonl --n-samples-per-eval-prompt 8 --eval-max-response-len 32768 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-glm4-9B-4xgpu-radixtree.sh b/scripts/run-glm4-9B-4xgpu-radixtree.sh index dbdc01c22..bd14b6d2c 100755 --- a/scripts/run-glm4-9B-4xgpu-radixtree.sh +++ b/scripts/run-glm4-9B-4xgpu-radixtree.sh @@ -17,8 +17,6 @@ export PYTHONBUFFERED=16 export CUDA_VISIBLE_DEVICES=0,1,2,3 -WANDB_KEY=8920a59faeab83c97b55c3cbe78618f11d0a1821 - NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l) if [ "$NVLINK_COUNT" -gt 0 ]; then HAS_NVLINK=1 @@ -51,7 +49,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -62,7 +60,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-glm4-9B.sh b/scripts/run-glm4-9B.sh index df613f187..b67523883 100644 --- a/scripts/run-glm4-9B.sh +++ b/scripts/run-glm4-9B.sh @@ -47,7 +47,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -58,7 +58,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-glm4.5-355B-A32B.sh b/scripts/run-glm4.5-355B-A32B.sh index 8771d1b77..0deaf0b88 100644 --- a/scripts/run-glm4.5-355B-A32B.sh +++ b/scripts/run-glm4.5-355B-A32B.sh @@ -42,7 +42,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 128 --n-samples-per-prompt 8 --rollout-max-response-len 32768 - --rollout-temperature 0.8 + --rollout-temperature 1 --over-sampling-batch-size 256 --dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std @@ -57,7 +57,7 @@ EVAL_ARGS=( --eval-prompt-data aime $BASE_DIR/rl_data/aime-2024.jsonl --n-samples-per-eval-prompt 8 --eval-max-response-len 32768 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-kimi-k2-Instruct.sh b/scripts/run-kimi-k2-Instruct.sh index 723432d9b..86919eff4 100644 --- a/scripts/run-kimi-k2-Instruct.sh +++ b/scripts/run-kimi-k2-Instruct.sh @@ -48,7 +48,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 128 --n-samples-per-prompt 8 --rollout-max-response-len 32768 - --rollout-temperature 0.8 + --rollout-temperature 1 # --global-batch-size 1024 @@ -64,7 +64,7 @@ EVAL_ARGS=( --eval-prompt-data aime $BASE_DIR/rl_data/aime-2024.jsonl --n-samples-per-eval-prompt 8 --eval-max-response-len 32768 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-kimi-k2-Thinking.sh b/scripts/run-kimi-k2-Thinking.sh index b559599ae..e5006b3b5 100644 --- a/scripts/run-kimi-k2-Thinking.sh +++ b/scripts/run-kimi-k2-Thinking.sh @@ -48,7 +48,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 128 --n-samples-per-prompt 8 --rollout-max-response-len 16384 - --rollout-temperature 0.8 + --rollout-temperature 1 # --global-batch-size 1024 @@ -64,7 +64,7 @@ EVAL_ARGS=( --eval-prompt-data aime $BASE_DIR/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-llama3.2-3B-Instruct-amd.sh b/scripts/run-llama3.2-3B-Instruct-amd.sh index 1e24a0660..a3036c304 100644 --- a/scripts/run-llama3.2-3B-Instruct-amd.sh +++ b/scripts/run-llama3.2-3B-Instruct-amd.sh @@ -62,7 +62,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 16384 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -73,7 +73,7 @@ EVAL_ARGS=( --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 8 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-mimo-7B-rl-eagle.sh b/scripts/run-mimo-7B-rl-eagle.sh index dec29d64d..092f25fef 100644 --- a/scripts/run-mimo-7B-rl-eagle.sh +++ b/scripts/run-mimo-7B-rl-eagle.sh @@ -47,7 +47,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -58,7 +58,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 1 --eval-max-response-len 8192 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-moonlight-16B-A3B.sh b/scripts/run-moonlight-16B-A3B.sh index 28234086a..ef695d398 100644 --- a/scripts/run-moonlight-16B-A3B.sh +++ b/scripts/run-moonlight-16B-A3B.sh @@ -46,7 +46,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 128 --n-samples-per-prompt 8 --rollout-max-response-len 4096 - --rollout-temperature 0.8 + --rollout-temperature 1 --over-sampling-batch-size 256 --dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std @@ -61,7 +61,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 8 --eval-max-response-len 4096 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-235B-A22B.sh b/scripts/run-qwen3-235B-A22B.sh index 7f4892ac6..e42e17ab2 100644 --- a/scripts/run-qwen3-235B-A22B.sh +++ b/scripts/run-qwen3-235B-A22B.sh @@ -58,7 +58,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 8 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 64 --balance-data @@ -69,7 +69,7 @@ EVAL_ARGS=( --eval-prompt-data aime ${BASE_FOLDER}/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B.sh index 5ddfbce6e..19bc70927 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/run-qwen3-30B-A3B.sh @@ -47,7 +47,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -58,7 +58,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-32B.sh b/scripts/run-qwen3-32B.sh index 8bee49577..f6eb8240a 100644 --- a/scripts/run-qwen3-32B.sh +++ b/scripts/run-qwen3-32B.sh @@ -45,7 +45,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -56,7 +56,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-4B-amd.sh b/scripts/run-qwen3-4B-amd.sh index 6710096fe..321a9712d 100755 --- a/scripts/run-qwen3-4B-amd.sh +++ b/scripts/run-qwen3-4B-amd.sh @@ -55,7 +55,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -66,7 +66,7 @@ EVAL_ARGS=( --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-4B-fsdp.sh b/scripts/run-qwen3-4B-fsdp.sh index 42200d182..3c95442d5 100644 --- a/scripts/run-qwen3-4B-fsdp.sh +++ b/scripts/run-qwen3-4B-fsdp.sh @@ -51,7 +51,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 8 --n-samples-per-prompt 8 --rollout-max-response-len 4096 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 64 ) diff --git a/scripts/run-qwen3-4B.sh b/scripts/run-qwen3-4B.sh index f41a9f714..c7f01abd9 100644 --- a/scripts/run-qwen3-4B.sh +++ b/scripts/run-qwen3-4B.sh @@ -46,7 +46,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -57,7 +57,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-4B_4xgpu-radixtree.sh b/scripts/run-qwen3-4B_4xgpu-radixtree.sh index e16f93d7a..cee7adb6d 100644 --- a/scripts/run-qwen3-4B_4xgpu-radixtree.sh +++ b/scripts/run-qwen3-4B_4xgpu-radixtree.sh @@ -47,7 +47,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data ) @@ -57,7 +57,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-4B_4xgpu.sh b/scripts/run-qwen3-4B_4xgpu.sh index aa893f03b..931ef0408 100755 --- a/scripts/run-qwen3-4B_4xgpu.sh +++ b/scripts/run-qwen3-4B_4xgpu.sh @@ -50,7 +50,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -61,7 +61,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-8B-amd.sh b/scripts/run-qwen3-8B-amd.sh index de814aab4..cdc18f5ca 100644 --- a/scripts/run-qwen3-8B-amd.sh +++ b/scripts/run-qwen3-8B-amd.sh @@ -63,7 +63,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -74,7 +74,7 @@ EVAL_ARGS=( --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-next-80B-A3B.sh b/scripts/run-qwen3-next-80B-A3B.sh index f18df3690..d5d725124 100644 --- a/scripts/run-qwen3-next-80B-A3B.sh +++ b/scripts/run-qwen3-next-80B-A3B.sh @@ -56,7 +56,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -67,7 +67,7 @@ EVAL_ARGS=( --eval-prompt-data aime ${BASE_FOLDER}/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run_deepseek.py b/scripts/run_deepseek.py index 2d4a563b5..05fee4447 100644 --- a/scripts/run_deepseek.py +++ b/scripts/run_deepseek.py @@ -124,7 +124,7 @@ def train(args: ScriptArgs): "--num-rollout 3000 " "--rollout-batch-size 128 " "--n-samples-per-prompt 8 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " # ------------ "--num-steps-per-rollout 4 " "--balance-data " @@ -139,7 +139,7 @@ def train(args: ScriptArgs): # sometimes disable eval to speed up debugging eval_args = "" if (args.mode != "debug_minimal") and args.enable_eval: - eval_args += "--eval-interval 20 " "--eval-top-p 0.7 " + eval_args += "--eval-interval 20 " "--eval-top-p 1 " match args.task: case "dapo_aime": diff --git a/scripts/run_glm45_355b_a32b.py b/scripts/run_glm45_355b_a32b.py index ca91f9257..8015fdc4d 100644 --- a/scripts/run_glm45_355b_a32b.py +++ b/scripts/run_glm45_355b_a32b.py @@ -134,7 +134,7 @@ def train(args: ScriptArgs): # TODO enlarge "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " # ------------ # TODO enlarge "--num-steps-per-rollout 1 " @@ -151,7 +151,7 @@ def train(args: ScriptArgs): # sometimes disable eval to speed up debugging eval_args = "" if (args.mode != "debug_minimal") and args.enable_eval: - eval_args += "--eval-interval 20 " "--eval-top-p 0.7 " + eval_args += "--eval-interval 20 " "--eval-top-p 1 " match args.task: case "dapo_aime": diff --git a/scripts/run_mcore_fsdp.py b/scripts/run_mcore_fsdp.py index f0c76f91f..354624091 100644 --- a/scripts/run_mcore_fsdp.py +++ b/scripts/run_mcore_fsdp.py @@ -82,7 +82,7 @@ def execute(args: ScriptArgs): f"--rollout-batch-size {8 if args.mode == 'debug_minimal' else 64} " f"--n-samples-per-prompt {8 if args.mode == 'debug_minimal' else 16} " f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 32768} " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " f"--global-batch-size {64 if args.mode == 'debug_minimal' else 1024} " ) @@ -122,7 +122,7 @@ def execute(args: ScriptArgs): "--eval-prompt-data aime /root/datasets/aime-2024/aime-2024.jsonl " "--n-samples-per-eval-prompt 16 " f"--eval-max-response-len {eval_max_response_len} " - "--eval-top-p 0.7 " + "--eval-top-p 1 " ) perf_args = ( diff --git a/scripts/run_qwen3_30b_a3b.py b/scripts/run_qwen3_30b_a3b.py index 6e226a613..26a374aa0 100644 --- a/scripts/run_qwen3_30b_a3b.py +++ b/scripts/run_qwen3_30b_a3b.py @@ -77,7 +77,7 @@ def execute(args: ScriptArgs): "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 8192} " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--global-batch-size 256 " "--balance-data " ) @@ -89,7 +89,7 @@ def execute(args: ScriptArgs): "--eval-prompt-data aime /root/datasets/aime-2024/aime-2024.jsonl " "--n-samples-per-eval-prompt 16 " "--eval-max-response-len 16384 " - "--eval-top-p 0.7 " + "--eval-top-p 1 " ) perf_args = ( diff --git a/scripts/run_qwen3_4b.py b/scripts/run_qwen3_4b.py index 8e074c2f9..d1aa63301 100644 --- a/scripts/run_qwen3_4b.py +++ b/scripts/run_qwen3_4b.py @@ -95,7 +95,7 @@ def execute(args: ScriptArgs): "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 8192} " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--global-batch-size 256 " "--balance-data " ) @@ -137,7 +137,7 @@ def execute(args: ScriptArgs): "--eval-prompt-data aime /root/datasets/aime-2024/aime-2024.jsonl " "--n-samples-per-eval-prompt 16 " f"--eval-max-response-len {eval_max_response_len} " - "--eval-top-p 0.7 " + "--eval-top-p 1 " ) grpo_args = ( diff --git a/setup.py b/setup.py index 22c18d712..8c0794320 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ def get_tag(self): setup( author="miles Team", name="miles", - version="0.2.0.post1", + version="0.2.1", packages=find_packages(include=["miles*", "miles_plugins*"]), include_package_data=True, install_requires=_fetch_requirements("requirements.txt"), diff --git a/tests/test_external_rollout.py b/tests/test_external_rollout.py index 117246bf1..c5c0838c5 100644 --- a/tests/test_external_rollout.py +++ b/tests/test_external_rollout.py @@ -33,7 +33,7 @@ def execute(): "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 1024 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--over-sampling-batch-size 64 " "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " "--global-batch-size 256 " diff --git a/tests/test_quick_start_glm4_9B.py b/tests/test_quick_start_glm4_9B.py index bf8453719..f18888c22 100644 --- a/tests/test_quick_start_glm4_9B.py +++ b/tests/test_quick_start_glm4_9B.py @@ -31,7 +31,7 @@ def execute(): "--rollout-batch-size 8 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 8192 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--global-batch-size 32 " "--balance-data " ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k.py b/tests/test_qwen2.5_0.5B_gsm8k.py index d63cbe12c..6302aadb6 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k.py +++ b/tests/test_qwen2.5_0.5B_gsm8k.py @@ -29,7 +29,7 @@ def execute(): "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 1024 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--over-sampling-batch-size 64 " "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " "--global-batch-size 256 " diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async.py b/tests/test_qwen2.5_0.5B_gsm8k_async.py index 181d425a6..1c55ccb20 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async.py @@ -28,7 +28,7 @@ def execute(): "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 1024 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--over-sampling-batch-size 64 " "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " "--global-batch-size 256 " diff --git a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py index dd7bb21df..6967f9145 100644 --- a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py +++ b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py @@ -23,7 +23,7 @@ def execute(): "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 1024 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--over-sampling-batch-size 64 " "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " "--global-batch-size 256 " diff --git a/tests/test_qwen3_0.6B_fsdp_distributed.py b/tests/test_qwen3_0.6B_fsdp_distributed.py index 4d3d8b94a..b3eb416b3 100644 --- a/tests/test_qwen3_0.6B_fsdp_distributed.py +++ b/tests/test_qwen3_0.6B_fsdp_distributed.py @@ -27,7 +27,7 @@ def execute(): "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 1024 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--over-sampling-batch-size 64 " "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " "--global-batch-size 256 " diff --git a/tests/test_qwen3_30B_A3B.py b/tests/test_qwen3_30B_A3B.py index b1eaca583..6b5f6b889 100644 --- a/tests/test_qwen3_30B_A3B.py +++ b/tests/test_qwen3_30B_A3B.py @@ -35,7 +35,7 @@ def execute(): "--rollout-batch-size 8 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 8192 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--global-batch-size 32 " "--balance-data " )