From 31006e4d6bfb8dd40ef6bbeed50a9c56b4147453 Mon Sep 17 00:00:00 2001 From: "miles.pr.bot" Date: Fri, 12 Dec 2025 08:07:09 +0800 Subject: [PATCH 01/21] update code --- examples/geo3k_vlm/README.md | 5 ++++- miles/backends/megatron_utils/loss.py | 22 ++++++++++------------ miles/ray/rollout.py | 2 +- miles/utils/arguments.py | 9 +++++++-- 4 files changed, 22 insertions(+), 16 deletions(-) diff --git a/examples/geo3k_vlm/README.md b/examples/geo3k_vlm/README.md index 4cf68cf58..1946999dd 100644 --- a/examples/geo3k_vlm/README.md +++ b/examples/geo3k_vlm/README.md @@ -29,4 +29,7 @@ All three performed similarly, so we use the default math RM for simplicity. Our initial geo3k-specific verifier produced "format scores" (**0 and 0.9**) instead of clean binary rewards. Under **fp32**, fractional values like 0.9 can't be exactly represented, so when all samples in a group have the same reward, `reward - mean` doesn't equal zero—creating spurious gradient signal. -We fixed this by switching to the default math RM with clean **binary 0/1 rewards**. If you encounter similar precision issues with non-binary rewards, you can change the reward tensor dtype from `torch.float` to `torch.float16` in `miles/ray/rollout.py` (`_post_process_rewards` method) to truncate precision artifacts. \ No newline at end of file +We fixed this by switching to the default math RM with clean **binary 0/1 rewards**. If you encounter similar precision issues with non-binary rewards, you can change the reward tensor dtype from `torch.float` to `torch.float16` in `miles/ray/rollout.py` (`_post_process_rewards` method) to truncate precision artifacts. 
+ +## B200 +Blackwell currently does not support fa3, we need to use `--sglang-mm-attention-backend sdpa` and `--attn-implementation flash_attention_2` \ No newline at end of file diff --git a/miles/backends/megatron_utils/loss.py b/miles/backends/megatron_utils/loss.py index be61684e5..52ae3a9f4 100644 --- a/miles/backends/megatron_utils/loss.py +++ b/miles/backends/megatron_utils/loss.py @@ -4,6 +4,7 @@ import torch from megatron.core import mpu +from torch.utils.checkpoint import checkpoint from miles.utils.distributed_utils import distributed_masked_whiten from miles.utils.misc import load_function @@ -740,26 +741,23 @@ def loss_function( args.calculate_per_token_loss, ) - loss_function_kwargs = { - "args": args, - "batch": batch, - "logits": logits, - "sum_of_sample_mean": sum_of_sample_mean, - } - match args.loss_type: case "policy_loss": - loss, log = policy_loss_function(**loss_function_kwargs) + func = policy_loss_function case "value_loss": - loss, log = value_loss_function(**loss_function_kwargs) + func = value_loss_function case "sft_loss": - loss, log = sft_loss_function(**loss_function_kwargs) + func = sft_loss_function case "custom_loss": - custom_loss_function = load_function(args.custom_loss_function_path) - loss, log = custom_loss_function(**loss_function_kwargs) + func = load_function(args.custom_loss_function_path) case _: raise ValueError(f"Unknown loss type: {args.loss_type}") + if args.recompute_loss_function: + loss, log = checkpoint(func, args, batch, logits, sum_of_sample_mean) + else: + loss, log = func(args, batch, logits, sum_of_sample_mean) + # Here we need to divide by cp_size because to cancel the multiply in Megatron. 
if not args.calculate_per_token_loss: loss = ( diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index b300ad167..9ee0fbb8a 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -305,7 +305,6 @@ def _split_train_data_by_dp(self, data, dp_size): "multimodal_inputs", "response_lengths", "rewards", - "raw_reward", "truncated", "loss_masks", "round_number", @@ -321,6 +320,7 @@ def _split_train_data_by_dp(self, data, dp_size): rollout_data[key] = val # keys that need to be splited at train side for key in [ + "raw_reward", "total_lengths", ]: if key not in data: diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 54ff574c9..b2d5cc13f 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -133,8 +133,8 @@ def add_train_arguments(parser): parser.add_argument( "--train-memory-margin-bytes", type=int, - default=0, - help="Add margin for train memory allocation.", + default=1024**3, + help="Add margin for train memory allocation. By default we will reserve 1GB as margin.", ) parser.add_argument( "--disable-weights-backuper", @@ -148,6 +148,11 @@ def add_train_arguments(parser): default="raw", help="The method to convert megatron weights to hugging face weights for SGLang.", ) + parser.add_argument( + "--recompute-loss-function", + action="store_true", + help="Whether to disable recompute loss function to save memory during training.", + ) return parser From 22f25278e71dbc0b5f7089b9bba9dc43bf7c1446 Mon Sep 17 00:00:00 2001 From: "Ethan (Yusheng) Su" Date: Fri, 12 Dec 2025 12:20:10 -0800 Subject: [PATCH 02/21] [Hardware] AMD - MI350/MI355 dockerfile (#306) Co-authored-by: gramesh-amd Co-authored-by: yushengsu-thu Co-authored-by: arist12 Co-authored-by: jhinpan Co-authored-by: zyzshishui Co-authored-by: guapisolo Co-authored-by: sunxxuns Co-authored-by: zhaochenyang20 --- docker/Dockerfile.rocm_MI350-5 | 252 +++++++++++++++++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 docker/Dockerfile.rocm_MI350-5 
diff --git a/docker/Dockerfile.rocm_MI350-5 b/docker/Dockerfile.rocm_MI350-5 new file mode 100644 index 000000000..6dc1353f0 --- /dev/null +++ b/docker/Dockerfile.rocm_MI350-5 @@ -0,0 +1,252 @@ +#### Use the base image for ROCm 7 / gfx950 (MI355) + +# The Docker image built with this Dockerfile: +# Base image: ROCm 7 with vllm pre-built for gfx950 +# Target GPU: MI355 (gfx950) + + +FROM rocm/sgl-dev:rocm7-vllm-20250904 + +SHELL ["/bin/bash", "-ceuxo", "pipefail"] + +ARG MAX_JOBS=128 +ENV MAX_JOBS=${MAX_JOBS} + +# Set environment variables for gfx950 +ENV GPU_ARCH=gfx950 +ENV PYTORCH_ROCM_ARCH=gfx950 +ENV GPU_ARCH_LIST=gfx950 +ENV AMDGPU_TARGET=gfx950 + + +########################################### +##############1. Install AITER############# +########################################### +WORKDIR /app + +RUN pip uninstall -y aiter || true +RUN rm -rf aiter +RUN git clone https://github.com/ROCm/aiter.git \ + && cd aiter \ + && git checkout v0.1.7.post2 \ + && git submodule update --init --recursive \ + && GPU_ARCHS=gfx950 python setup.py develop +########################################### +########################################### +########################################### + + +########################################### +####2. Install TransformerEngine for gfx950 +########################################### +WORKDIR /app + +RUN rm -rf TransformerEngine +RUN git clone https://github.com/ROCm/TransformerEngine.git \ + && cd TransformerEngine \ + && git checkout 90c04bcdc3c109505b318f40a39680263af55edf \ + && git submodule update --init --recursive + +ENV NVTE_FRAMEWORK=pytorch +ENV NVTE_ROCM_ARCH=gfx950 +ENV NVTE_USE_HIPBLASLT=1 +ENV NVTE_USE_ROCM=1 +ENV CMAKE_PREFIX_PATH="/opt/rocm:/opt/rocm/hip:/usr/local:/usr" + +RUN cd TransformerEngine && pip install . -v +########################################### +########################################### +########################################### + + +######################################### +####3. 
Install Megatron-LM (NVIDIA version) +######################################### +WORKDIR /app + +RUN pip install "numpy>=1.21.0,<2.0" --force-reinstall + +RUN pip uninstall -y megatron-core || true +RUN rm -rf Megatron-LM +RUN git clone https://github.com/NVIDIA/Megatron-LM \ + && cd Megatron-LM \ + && git checkout 48406695c4efcf1026a7ed70bb390793918dd97b \ + && pip install -e . +######################################### +######################################### +######################################### + + +######################################## +############ 4. Install mbridge######### +######################################## +RUN pip install git+https://github.com/ISEEKYAN/mbridge.git --no-deps +######################################## +######################################## +######################################## + + +######################################## +######5. Install Ray#################### +######################################## +RUN pip uninstall ray -y || true +RUN pip install "ray[data,train,tune,serve]==2.47.1" +######################################## +######################################## +######################################## + + +######################################### +###6. Install torch_memory_saver######### +######################################### +RUN pip install torch_memory_saver +######################################### +######################################### + + +####################################### +####7. Install Apex for ROCm########### +####################################### +WORKDIR /app + +RUN pip uninstall -y apex || true +RUN rm -rf apex +RUN git clone https://github.com/ROCm/apex.git \ + && cd apex \ + && python setup.py install +####################################### +####################################### +####################################### + + +######################################## +###8. 
Install slime agent framework deps +######################################## +RUN pip install pydra_config==0.0.15 +RUN pip install together +RUN pip install google-generativeai +RUN pip install tensorboard +######################################## +######################################## +######################################## + + +######################################## +###9. Set performance environment vars## +######################################## +ENV HIP_FORCE_DEV_KERNARG=1 +ENV HSA_NO_SCRATCH_RECLAIM=1 +ENV SGLANG_USE_AITER=1 +ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 +ENV SGLANG_MOE_PADDING=1 +ENV SGLANG_SET_CPU_AFFINITY=1 +ENV SGLANG_ROCM_FUSED_DECODE_MLA=1 +ENV SGLANG_USE_ROCM700A=1 +ENV NCCL_MIN_NCHANNELS=112 +ENV VLLM_FP8_PADDING=1 +ENV VLLM_FP8_ACT_PADDING=1 +ENV VLLM_FP8_WEIGHT_PADDING=1 +ENV VLLM_FP8_REDUCE_CONV=1 +ENV TORCHINDUCTOR_MAX_AUTOTUNE=1 +ENV TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 +######################################## +######################################## +######################################## + + +########################################### +##############Install SGLang############### +########################################### +WORKDIR /app + +# Install prerequisites +RUN pip install IPython orjson python-multipart torchao==0.9.0 pybind11 + +# Clone SGLang +RUN pip uninstall -y sgl_kernel sglang || true +RUN rm -rf sglang +RUN git clone https://github.com/sgl-project/sglang.git \ + && cd sglang \ + && git checkout v0.5.6 + +# Build sgl-kernel for gfx950 +RUN cd sglang/sgl-kernel \ + && rm -f pyproject.toml \ + && mv pyproject_rocm.toml pyproject.toml \ + && AMDGPU_TARGET=gfx950 python setup_rocm.py install + +# Install SGLang +RUN cd sglang \ + && rm -rf python/pyproject.toml \ + && mv python/pyproject_other.toml python/pyproject.toml \ + && pip install -e "python[all_hip]" + +# Test SGLang installation +RUN python -c "import sglang; import sgl_kernel; print('SGLang + sgl_kernel: OK')" + +RUN python -m pip cache 
purge +########################################### +########################################### +########################################### + + +########################################### +#### APPLY PATCHES (gfx950/MI355) ######### +########################################### + +# Copy patches from slime repo +COPY amd_patch/latest /app/patch + +# Apply Megatron patches +RUN cd /app/Megatron-LM \ + && git apply /app/patch/amd_megatron_fused_kernels_init.patch \ + && git apply /app/patch/megatron.patch --3way \ + && if grep -R -n '^<<<<<<< ' .; then \ + echo "Patch failed to apply cleanly. Please resolve conflicts." && \ + exit 1; \ + fi \ + && pip install -e . -v + +# Apply SGLang patch +RUN cd /app/sglang \ + && git apply /app/patch/sglang.patch || echo "Check patch compatibility with v0.5.6" \ + && if grep -R -n '^<<<<<<< ' .; then \ + echo "Patch failed to apply cleanly. Please resolve conflicts." && \ + exit 1; \ + fi + +# Copy MOE configs for gfx950/MI355 +RUN find /app/sglang/python/sglang/srt/layers/quantization/configs/ \ + /app/sglang/python/sglang/srt/layers/moe/fused_moe_triton/configs/ \ + -type f -name '*MI300X*' 2>/dev/null | while read f; do \ + cp "$f" "$(echo $f | sed 's/MI300X/MI300X_VF/')" 2>/dev/null || true; \ + cp "$f" "$(echo $f | sed 's/MI300X/MI355/')" 2>/dev/null || true; \ +done + +########################################### +########################################### +########################################### + + +######################################## +#### Install additional packages######## +######################################## +RUN pip install sglang-router --force-reinstall +######################################## +######################################## +######################################## + + +######################################## +# Fix click/ray incompatibility with Python 3.10 +######################################## +RUN pip install click==8.2.1 +######################################## 
+######################################## +######################################## + + +WORKDIR /app + +CMD ["/usr/bin/bash"] + From 61a319ae0b794f291f4e02529968dbe007ceb121 Mon Sep 17 00:00:00 2001 From: "miles.pr.bot" Date: Sat, 13 Dec 2025 13:37:21 +0800 Subject: [PATCH 03/21] update code --- build_conda.sh | 17 +- docker/README.md | 4 +- docker/patch/latest/megatron.patch | 4 +- docker/patch/latest/sglang.patch | 16 +- docker/patch/v0.5.6/megatron.patch | 869 ++++++++ docker/patch/v0.5.6/sglang.patch | 2053 +++++++++++++++++++ docker/version.txt | 2 +- examples/geo3k_vlm/run_geo3k_vlm.py | 24 +- examples/true_on_policy_vlm/README.md | 1 + examples/true_on_policy_vlm/run_simple.py | 25 +- miles/backends/fsdp_utils/actor.py | 3 +- miles/backends/megatron_utils/actor.py | 6 + miles/backends/megatron_utils/initialize.py | 1 - miles/ray/train_actor.py | 6 - miles/utils/flops_utils.py | 12 +- setup.py | 2 +- 16 files changed, 2970 insertions(+), 75 deletions(-) create mode 100644 docker/patch/v0.5.6/megatron.patch create mode 100644 docker/patch/v0.5.6/sglang.patch diff --git a/build_conda.sh b/build_conda.sh index ac1bb2091..46fc12df6 100644 --- a/build_conda.sh +++ b/build_conda.sh @@ -21,13 +21,13 @@ micromamba install -n miles cuda cuda-nvtx cuda-nvtx-dev nccl -c nvidia/label/cu micromamba install -n miles -c conda-forge cudnn -y # prevent installing cuda 13.0 for sglang -pip install cuda-python==12.9.1 -pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu129 +pip install cuda-python==13.1.0 +pip install torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu129 # install sglang git clone https://github.com/sgl-project/sglang.git cd sglang -git checkout 303cc957e62384044dfa8e52d7d8af8abe12f0ac +git checkout 5e2cda6158e670e64b926a9985d65826c537ac82 # Install the python packages pip install -e "python[all]" @@ -39,7 +39,7 @@ pip install cmake ninja MAX_JOBS=64 pip 
-v install flash-attn==2.7.4.post1 --no-build-isolation pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps -pip install --no-build-isolation "transformer_engine[pytorch]==2.8.0" +pip install --no-build-isolation "transformer_engine[pytorch]==2.10.0" pip install flash-linear-attention==0.4.0 NVCC_APPEND_FLAGS="--threads 4" \ pip -v install --disable-pip-version-check --no-cache-dir \ @@ -50,7 +50,7 @@ git clone https://github.com/NVIDIA/Megatron-LM.git --recursive && \ cd Megatron-LM && git checkout ${MEGATRON_COMMIT} && \ pip install -e . -pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@9b8b788fdeb9c2ee528183214cef65a99b71e7d5 --no-cache-dir --force-reinstall +pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@dc6876905830430b5054325fa4211ff302169c6b --no-cache-dir --force-reinstall pip install git+https://github.com/fzyzcjy/Megatron-Bridge.git@dev_rl --no-build-isolation pip install nvidia-modelopt[torch]>=0.37.0 --no-build-isolation @@ -60,6 +60,9 @@ git clone https://github.com/NVIDIA/Megatron-LM.git --recursive && \ cd Megatron-LM/ && git checkout core_v0.14.0 && \ pip install -e . +# https://github.com/pytorch/pytorch/issues/168167 +pip install nvidia-cudnn-cu12==9.16.0.29 + # install miles and apply patches # if miles does not exist locally, clone it @@ -76,6 +79,6 @@ fi # apply patch cd $BASE_DIR/sglang -git apply $MILES_DIR/docker/patch/v0.5.5.post1/sglang.patch +git apply $MILES_DIR/docker/patch/v0.5.6/sglang.patch cd $BASE_DIR/Megatron-LM -git apply $MILES_DIR/docker/patch/v0.5.5.post1/megatron.patch +git apply $MILES_DIR/docker/patch/v0.5.6/megatron.patch \ No newline at end of file diff --git a/docker/README.md b/docker/README.md index 92f559e72..156169c72 100644 --- a/docker/README.md +++ b/docker/README.md @@ -5,10 +5,10 @@ We will publish 2 kinds of docker images: 2. latest version, which aligns to `lmsysorg/sglang:latest`. 
current stable version is: -- sglang v0.5.5.post1 (303cc957e62384044dfa8e52d7d8af8abe12f0ac), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133) +- sglang nightly-dev-20251208-5e2cda61 (5e2cda6158e670e64b926a9985d65826c537ac82), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133) history versions: -- sglang v0.5.0rc0-cu126 (8ecf6b9d2480c3f600826c7d8fef6a16ed603c3f), megatron 48406695c4efcf1026a7ed70bb390793918dd97b +- sglang v0.5.5.post1 (303cc957e62384044dfa8e52d7d8af8abe12f0ac), megatron v0.14.0 (23e00ed0963c35382dfe8a5a94fb3cda4d21e133) The command to build: diff --git a/docker/patch/latest/megatron.patch b/docker/patch/latest/megatron.patch index 0526ac55a..3a56ff4c2 100644 --- a/docker/patch/latest/megatron.patch +++ b/docker/patch/latest/megatron.patch @@ -219,14 +219,14 @@ index 6aec66e6d..6ca48b55f 100644 mtp_loss = loss_mask * mtp_loss if self.training: diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py -index a36b67364..8739270f2 100644 +index a36b67364..ed8883e32 100644 --- a/megatron/core/optimizer/distrib_optimizer.py +++ b/megatron/core/optimizer/distrib_optimizer.py @@ -657,6 +657,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): # TE FusedAdam will not accumulate step for empty param groups, so we need to # align the step across param groups. param_group["step"] = int(step) -+ if param_group["step"] is None: ++ if "step" in param_group and param_group["step"] is None: + del param_group["step"] # Grad scaler state. 
diff --git a/docker/patch/latest/sglang.patch b/docker/patch/latest/sglang.patch index 055e09096..de12cdd43 100644 --- a/docker/patch/latest/sglang.patch +++ b/docker/patch/latest/sglang.patch @@ -215,7 +215,7 @@ index 932f52aeb..79c6b664f 100644 hidden_states = self._communicate_simple_fn( diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py -index 3293a8a59..02999afd0 100644 +index 3293a8a59..a075b71ce 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -84,15 +84,12 @@ class RMSNorm(CustomOp): @@ -236,7 +236,7 @@ index 3293a8a59..02999afd0 100644 self.variance_epsilon = eps self.hidden_size = hidden_size self.variance_size_override = ( -@@ -105,15 +102,16 @@ class RMSNorm(CustomOp): +@@ -105,21 +102,26 @@ class RMSNorm(CustomOp): self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, @@ -255,7 +255,17 @@ index 3293a8a59..02999afd0 100644 return rms_norm_batch_invariant( x, self.weight.data, -@@ -179,17 +177,35 @@ class RMSNorm(CustomOp): + self.variance_epsilon, + ) + if residual is not None: ++ # TODO: Ideally we want to have (a+b)+c. but right now we can only have a+(b+c). 
++ # (a+b)+c != a+(b+c), we probably need to add another parameter to fused_add_rmsnorm ++ if post_residual_addition is not None: ++ residual = residual + post_residual_addition + fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) + return x, residual + out = rmsnorm(x, self.weight.data, self.variance_epsilon) +@@ -179,17 +181,35 @@ class RMSNorm(CustomOp): self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, diff --git a/docker/patch/v0.5.6/megatron.patch b/docker/patch/v0.5.6/megatron.patch new file mode 100644 index 000000000..3a56ff4c2 --- /dev/null +++ b/docker/patch/v0.5.6/megatron.patch @@ -0,0 +1,869 @@ +diff --git a/megatron/core/dist_checkpointing/strategies/common.py b/megatron/core/dist_checkpointing/strategies/common.py +index 41c21d93d..ef80f72d6 100644 +--- a/megatron/core/dist_checkpointing/strategies/common.py ++++ b/megatron/core/dist_checkpointing/strategies/common.py +@@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): + msc = MultiStorageClientFeature.import_package() + return msc.torch.load(load_path, map_location='cpu') + else: +- return torch.load(load_path, map_location='cpu') ++ return torch.load(load_path, map_location='cpu', weights_only=False) + except FileNotFoundError as e: + err_msg = f'Common file {load_path} does not exist' + if MultiStorageClientFeature.is_enabled(): +diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py +index ccf5242a2..9b6d3e31f 100644 +--- a/megatron/core/dist_checkpointing/strategies/torch.py ++++ b/megatron/core/dist_checkpointing/strategies/torch.py +@@ -427,6 +427,15 @@ def _restore_dict_types(x: Union[dict, list, Any], keys_template: Union[dict, li + _restore_dict_types(x_val, templ_val) + + ++@dataclass ++class MCoreMetadata(Metadata): ++ """Metadata with mcore specific data.""" ++ ++ # holds data related to flattened_range ++ # TODO: remove when flattened_range is properly removed ++ 
mcore_data: Optional[Dict[str, Dict[str, Any]]] = None # Mcore related data about each tensor ++ ++ + @dataclass(frozen=True) + class MCoreSavePlan(SavePlan): + """SavePlan with MCore specific data.""" +@@ -499,9 +508,10 @@ class MCoreSavePlanner(DefaultSavePlanner): + def create_global_plan(self, all_plans: List[MCoreSavePlan]) -> Tuple[List[SavePlan], Metadata]: + """Merges MCore data for all plans.""" + global_plan, metadata = super().create_global_plan(all_plans) +- metadata.mcore_data = dict( ++ mcore_data = dict( + ChainMap(*(plan.mcore_data for plan in all_plans)) # type: ignore[arg-type] + ) ++ metadata = MCoreMetadata(mcore_data=mcore_data, **vars(metadata)) + return global_plan, metadata + + def create_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan: +@@ -556,10 +566,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + def _validate_global_shapes(self, metadata, sharded_tensors): + for sh_ten in sharded_tensors: + if sh_ten.key not in metadata.state_dict_metadata: +- raise KeyError( +- f"{sh_ten.key} from model not in state dict:" +- f" {sorted(metadata.state_dict_metadata.keys())}" +- ) ++ # raise KeyError( ++ # f"{sh_ten.key} from model not in state dict:" ++ # f" {sorted(metadata.state_dict_metadata.keys())}" ++ # ) ++ print(f"{sh_ten.key} from model not in state dict, will skip") ++ continue + loaded_shape = metadata.state_dict_metadata[sh_ten.key].size + expected_shape = self._expected_shape(sh_ten) + if loaded_shape != expected_shape: +@@ -589,7 +601,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): + tensor_metadata = self.metadata.state_dict_metadata + metadata_with_sizes = [ + (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) +- for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() ++ for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata + ] + try: + # Temporarily set sizes to expected shapes +@@ -918,6 +930,7 @@ class 
TorchDistLoadShardedStrategy(LoadShardedStrategy): + planner=MCoreLoadPlanner( + shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, + allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, ++ allow_partial_load=True, + ), + ) + +diff --git a/megatron/core/distributed/__init__.py b/megatron/core/distributed/__init__.py +index fe26e8b43..4451f2776 100644 +--- a/megatron/core/distributed/__init__.py ++++ b/megatron/core/distributed/__init__.py +@@ -11,3 +11,15 @@ from .finalize_model_grads import finalize_model_grads + from .fsdp.mcore_fsdp_adapter import FullyShardedDataParallel + from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel + from .torch_fully_sharded_data_parallel_config import TorchFullyShardedDataParallelConfig ++ ++# Backward compatibility patch for FSDP module reorganization ++import sys ++import importlib.util ++ ++spec = importlib.util.find_spec('megatron.core.distributed.fsdp.src.megatron_fsdp') ++if spec: ++ custom_fsdp = importlib.util.module_from_spec(spec) ++ spec.loader.exec_module(custom_fsdp) ++ sys.modules['megatron.core.distributed.custom_fsdp'] = custom_fsdp ++ if hasattr(custom_fsdp, 'MegatronFSDP'): ++ custom_fsdp.FullyShardedDataParallel = custom_fsdp.MegatronFSDP +diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py +index 7727efe1e..966fe652a 100644 +--- a/megatron/core/extensions/transformer_engine.py ++++ b/megatron/core/extensions/transformer_engine.py +@@ -366,6 +366,7 @@ class TELinear(te.pytorch.Linear): + ) + + for param in self.parameters(): ++ setattr(param, "parallel_mode", parallel_mode) + if is_expert: + # Reduce the gradient on the expert_data_parallel group for expert linear layers + setattr(param, "allreduce", not self.expert_parallel) +diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py +index 860ee64a9..80944b702 100755 +--- 
a/megatron/core/models/gpt/gpt_layer_specs.py ++++ b/megatron/core/models/gpt/gpt_layer_specs.py +@@ -79,6 +79,8 @@ def get_gpt_layer_with_transformer_engine_spec( + qk_l2_norm: Optional[bool] = False, + use_te_op_fuser: Optional[bool] = False, + use_kitchen: bool = False, ++ post_self_attn_layernorm: bool = False, ++ post_mlp_layernorm: bool = False, + ) -> ModuleSpec: + """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). + +@@ -178,9 +180,11 @@ def get_gpt_layer_with_transformer_engine_spec( + ), + ), + self_attn_bda=get_bias_dropout_add, ++ post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, + pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, ++ post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, + sharded_state_dict_keys_map={ + "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight", + "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias", +diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py +index 6aec66e6d..6ca48b55f 100644 +--- a/megatron/core/models/gpt/gpt_model.py ++++ b/megatron/core/models/gpt/gpt_model.py +@@ -355,6 +355,7 @@ class GPTModel(LanguageModule): + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, ++ mtp_kwargs: Optional[dict] = {}, + ) -> Tensor: + """Forward function of the GPT Model This function passes the input tensors + through the embedding layer, and then the decoeder and finally into the post +@@ -410,6 +411,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, ++ mtp_kwargs=mtp_kwargs, + ) + + def _postprocess( +@@ -431,6 +433,7 @@ class GPTModel(LanguageModule): + runtime_gather_output=None, + extra_block_kwargs=None, + inference_context=None, ++ mtp_kwargs={}, + ): + """Postprocesses decoder hidden states 
to generate logits or compute loss. + +@@ -446,7 +449,7 @@ class GPTModel(LanguageModule): + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + +- if mtp_in_postprocess: ++ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, +@@ -465,25 +468,37 @@ class GPTModel(LanguageModule): + if not self.post_process: + return hidden_states + +- if self.mtp_process: +- mtp_labels = labels.clone() ++ if self.mtp_process and mtp_kwargs.get('mtp_labels', None) is not None: ++ mtp_labels = mtp_kwargs['mtp_labels'].clone() ++ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) ++ + hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) + hidden_states = hidden_states_list[0] + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(mtp_labels) ++ else: ++ # Otherwise, roll the loss_mask to keep up with the mtp_labels ++ loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) + for mtp_layer_number in range(self.config.mtp_num_layers): + # output +- mtp_logits, _ = self.output_layer( +- hidden_states_list[mtp_layer_number + 1], +- weight=output_weight, +- runtime_gather_output=runtime_gather_output, ++ output_layer_params = {k: v.detach() for k, v in self.output_layer.named_parameters()} ++ output_layer_buffers = dict(self.output_layer.named_buffers()) ++ mtp_logits, _ = torch.func.functional_call( ++ self.output_layer, ++ {**output_layer_params, **output_layer_buffers}, ++ (hidden_states_list[mtp_layer_number + 1],), ++ { ++ "weight": output_weight.detach() if output_weight else None, ++ "runtime_gather_output": runtime_gather_output, ++ }, + ) + # Calc loss for the current Multi-Token Prediction (MTP) layers. 
+- mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group) +- loss_mask, num_tokens = roll_tensor( +- loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group ++ mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) ++ new_loss_mask, num_tokens = roll_tensor( ++ loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params + ) ++ loss_mask = new_loss_mask * loss_mask + mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits) + mtp_loss = loss_mask * mtp_loss + if self.training: +diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py +index a36b67364..ed8883e32 100644 +--- a/megatron/core/optimizer/distrib_optimizer.py ++++ b/megatron/core/optimizer/distrib_optimizer.py +@@ -657,6 +657,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): + # TE FusedAdam will not accumulate step for empty param groups, so we need to + # align the step across param groups. + param_group["step"] = int(step) ++ if "step" in param_group and param_group["step"] is None: ++ del param_group["step"] + + # Grad scaler state. 
+ if self.grad_scaler: +diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py +index a40c85a88..86688c331 100644 +--- a/megatron/core/parallel_state.py ++++ b/megatron/core/parallel_state.py +@@ -9,6 +9,7 @@ from typing import Callable, List, Optional + + import numpy as np + import torch ++import torch.distributed as dist + + from .utils import GlobalMemoryBuffer, is_torch_min_version + +diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py +index 63ee9d1f5..b90b744c1 100644 +--- a/megatron/core/pipeline_parallel/p2p_communication.py ++++ b/megatron/core/pipeline_parallel/p2p_communication.py +@@ -26,22 +26,22 @@ def _batched_p2p_ops( + ops = [] + if tensor_send_prev is not None: + send_prev_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group ++ torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, + ) + ops.append(send_prev_op) + if tensor_recv_prev is not None: + recv_prev_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, + ) + ops.append(recv_prev_op) + if tensor_send_next is not None: + send_next_op = torch.distributed.P2POp( +- torch.distributed.isend, tensor_send_next, next_pipeline_rank, group ++ torch.distributed.isend, tensor_send_next, next_pipeline_rank, + ) + ops.append(send_next_op) + if tensor_recv_next is not None: + recv_next_op = torch.distributed.P2POp( +- torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group ++ torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, + ) + ops.append(recv_next_op) + if len(ops) > 0: +diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py +index c749bac43..dde8d50e7 100644 +--- a/megatron/core/transformer/attention.py ++++ b/megatron/core/transformer/attention.py +@@ -670,7 +670,10 @@ class 
Attention(MegatronModule, ABC): + # Get the query, key and value tensors based on the type of attention - + # self or cross attn. + nvtx_range_push(suffix="qkv") +- query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) ++ if self.config.use_gated_attention: ++ query, gate, key, value = self.get_query_gate_key_value_tensors(hidden_states, key_value_states) ++ else: ++ query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + nvtx_range_pop(suffix="qkv") + + # =================================================== +@@ -842,6 +845,11 @@ class Attention(MegatronModule, ABC): + # Output. [sq, b, h] + # ================= + ++ if self.config.use_gated_attention: ++ nvtx_range_push(suffix="sigmoid_gate") ++ core_attn_out = core_attn_out * torch.sigmoid(gate) ++ nvtx_range_pop(suffix="sigmoid_gate") ++ + nvtx_range_push(suffix="linear_proj") + output, bias = self.linear_proj(core_attn_out) + nvtx_range_pop(suffix="linear_proj") +@@ -879,19 +887,34 @@ class SelfAttention(Attention): + model_comm_pgs=model_comm_pgs, + ) + +- self.linear_qkv = build_module( +- submodules.linear_qkv, +- self.config.hidden_size, +- self.query_projection_size + 2 * self.kv_projection_size, +- config=self.config, +- init_method=self.config.init_method, +- gather_output=False, +- bias=self.config.add_bias_linear or self.config.add_qkv_bias, +- skip_bias_add=False, +- is_expert=False, +- tp_comm_buffer_name='qkv', +- tp_group=self.model_comm_pgs.tp, +- ) ++ if self.config.use_gated_attention: ++ self.linear_qgkv = build_module( ++ submodules.linear_qkv, ++ self.config.hidden_size, ++ 2 * (self.query_projection_size + self.kv_projection_size), ++ config=self.config, ++ init_method=self.config.init_method, ++ gather_output=False, ++ bias=self.config.add_bias_linear or self.config.add_qkv_bias, ++ skip_bias_add=False, ++ is_expert=False, ++ tp_comm_buffer_name='qkv', ++ tp_group=self.model_comm_pgs.tp, ++ ) ++ else: ++ self.linear_qkv = 
build_module( ++ submodules.linear_qkv, ++ self.config.hidden_size, ++ self.query_projection_size + 2 * self.kv_projection_size, ++ config=self.config, ++ init_method=self.config.init_method, ++ gather_output=False, ++ bias=self.config.add_bias_linear or self.config.add_qkv_bias, ++ skip_bias_add=False, ++ is_expert=False, ++ tp_comm_buffer_name='qkv', ++ tp_group=self.model_comm_pgs.tp, ++ ) + + if submodules.q_layernorm is not None: + self.q_layernorm = build_module( +@@ -1036,6 +1059,65 @@ class SelfAttention(Attention): + + return query, key, value + ++ # adapt from https://github.com/alibaba/Pai-Megatron-Patch/blob/8e6cbb0556ba09933ab4a4edb23c0af1d19d9960/megatron_patch/model/qwen3_next/gated_attention.py#L192 ++ def get_query_gate_key_value_tensors(self, hidden_states, key_value_states=None): ++ """ ++ Derives `query`, `key` and `value` tensors from `hidden_states`. ++ """ ++ # Attention heads [sq, b, h] --> [sq, b, ng * 2 * (np/ng + 1) * hn)] ++ mixed_qgkv, _ = self.linear_qgkv(hidden_states) ++ ++ # [sq, b, hp] --> [sq, b, ng, 2 * (np/ng + 1) * hn] ++ new_tensor_shape = mixed_qgkv.size()[:-1] + ( ++ self.num_query_groups_per_partition, ++ ( ++ 2 * (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 1) ++ * self.hidden_size_per_attention_head ++ ), ++ ) ++ mixed_qgkv = mixed_qgkv.view(*new_tensor_shape) ++ ++ split_arg_list = [ ++ ( ++ self.num_attention_heads_per_partition ++ // self.num_query_groups_per_partition ++ * self.hidden_size_per_attention_head ++ ), ++ ( ++ self.num_attention_heads_per_partition ++ // self.num_query_groups_per_partition ++ * self.hidden_size_per_attention_head ++ ), ++ self.hidden_size_per_attention_head, ++ self.hidden_size_per_attention_head, ++ ] ++ ++ if SplitAlongDim is not None: ++ ++ # [sq, b, ng, (np/ng + 2) * hn] ++ # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] ++ (query, gate, key, value) = SplitAlongDim(mixed_qgkv, 3, split_arg_list) ++ else: ++ ++ # [sq, b, ng, (np/ng + 
2) * hn] ++ # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] ++ (query, gate, key, value) = torch.split(mixed_qgkv, split_arg_list, dim=3) ++ ++ # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] ++ query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) ++ gate = gate.reshape(query.size(0), query.size(1), -1) ++ ++ if self.q_layernorm is not None: ++ query = self.q_layernorm(query) ++ ++ if self.k_layernorm is not None: ++ key = self.k_layernorm(key) ++ ++ if self.config.test_mode: ++ self.run_realtime_tests() ++ ++ return query, gate, key, value ++ + def backward_dw(self) -> NoReturn: + """Execute weight update operations""" + self._backward_qkv_proj() +diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py +index 235b6f6af..fbcffe278 100644 +--- a/megatron/core/transformer/moe/moe_utils.py ++++ b/megatron/core/transformer/moe/moe_utils.py +@@ -566,6 +566,9 @@ def topk_routing_with_score_function( + else: + return torch.topk(scores, k=topk, dim=1) + ++ from miles.utils.routing_replay import get_routing_replay_compute_topk ++ compute_topk = get_routing_replay_compute_topk(compute_topk) ++ + if score_function == "softmax": + if use_pre_softmax: + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) +diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py +index 6b20b8622..459e65921 100644 +--- a/megatron/core/transformer/moe/router.py ++++ b/megatron/core/transformer/moe/router.py +@@ -156,6 +156,9 @@ class TopKRouter(Router): + self.local_tokens_per_expert = None + self.expert_bias = None + ++ from miles.utils.routing_replay import register_routing_replay ++ register_routing_replay(self) ++ + def _maintain_float32_expert_bias(self): + """ + Maintain the expert bias in float32. 
+diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py +index b7884e18e..f0104f861 100755 +--- a/megatron/core/transformer/multi_token_prediction.py ++++ b/megatron/core/transformer/multi_token_prediction.py +@@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union + + import torch + from torch import Tensor ++import warnings + + from megatron.core import InferenceParams, mpu, parallel_state, tensor_parallel + from megatron.core.dist_checkpointing.mapping import ShardedStateDict +@@ -105,17 +106,21 @@ def tie_output_layer_state_dict( + ) + + +-def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): +- """Roll the tensor input along the sequence dimension with Context Parallelism (CP) support. + +- This function extends the original roll_tensor to support Context Parallelism, which allows +- MTP to work with CP > 1. When CP is enabled, the sequence dimension is split across CP ranks, +- and tensor rolling requires communication between adjacent CP ranks to properly handle the +- boundary conditions. ++def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None, packed_seq_params=None): ++ """Roll the tensor input along the sequence dimension with Context Parallelism (CP) and Packed Sequence support. ++ ++ This function extends the original roll_tensor to support Context Parallelism and Packed Sequences. ++ When CP is enabled, the sequence dimension is split across CP ranks, and tensor rolling requires ++ communication between adjacent CP ranks to properly handle the boundary conditions. ++ When packed sequences are used, rolling is performed within each individual sequence boundary ++ to prevent mixing tokens between different packed sequences. + + For CP=1 (default behavior): Uses standard torch.roll with zero padding + For CP>1: Splits tensor into chunks, performs rolling within each chunk, then exchanges + boundary elements between adjacent CP ranks to maintain sequence continuity. 
++ For packed sequences: Rolls tensors within sequence boundaries defined by cu_seqlens. ++ + + Args: + tensor (Tensor): The input tensor to roll. +@@ -123,9 +128,15 @@ def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): + dims (int): The dimension to roll (typically -1 for sequence dimension). + cp_group (ProcessGroup): The context parallelism process group. If None or size=1, + falls back to standard rolling behavior. ++ packed_seq_params (PackedSeqParams): Parameters for packed sequence processing. ++ If provided, rolling respects sequence boundaries. + Returns: + tuple: (rolled_tensor, sum_of_rolled_tensor) + """ ++ ++ if packed_seq_params is not None: ++ return _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group) ++ + # Standard rolling behavior when CP is not enabled (cp_group is None or size=1) + if cp_group is None or cp_group.size() == 1: + rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims) +@@ -193,6 +204,103 @@ def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None): + + return rolled_tensor, rolled_tensor.sum() + ++def _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group=None): ++ """Roll tensor with packed sequence support. ++ ++ This function handles rolling for packed sequences by respecting sequence boundaries ++ defined in packed_seq_params.cu_seqlens. Rolling is performed within each individual ++ sequence to prevent mixing tokens between different packed sequences. When Context ++ Parallelism (CP) is enabled, each CP rank still receives the full `cu_seqlens` metadata ++ so we slice out the portion of every packed sequence that lives on the current rank and ++ reuse the standard CP boundary exchange to populate the rolling window. ++ ++ Args: ++ tensor (Tensor): The input tensor to roll. ++ shifts (int): The shift of the tensor (typically -1 for MTP). ++ dims (int): The dimension to roll (typically -1 for sequence dimension). 
++ packed_seq_params (PackedSeqParams): Parameters for packed sequence processing. ++ cp_group (ProcessGroup): The context parallelism process group. ++ ++ Returns: ++ tuple: (rolled_tensor, sum_of_rolled_tensor) ++ """ ++ ++ # Notice: This is a naive implementation to test the correctness, a better solution will only sync the boundary tokens once. ++ assert dims == -1 or dims == tensor.dim() - 1, "Packed sequence roll only supports the last dimension." ++ assert shifts == -1, "Packed sequence roll only supports a single-token left shift." ++ cu_seqlens = packed_seq_params.cu_seqlens_q ++ assert cu_seqlens is not None, "Packed sequence parameters must provide cu_seqlens_q." ++ ++ rolled_tensor = tensor.clone() ++ ++ cp_size = cp_group.size() if cp_group is not None else 1 ++ if cp_size == 1: ++ # CP disabled: simply roll inside each packed sequence boundary. ++ for i in range(len(cu_seqlens) - 1): ++ start_idx = cu_seqlens[i] ++ end_idx = cu_seqlens[i + 1] ++ seq_slice = tensor[..., start_idx:end_idx] ++ rolled_seq = torch.roll(seq_slice, shifts=shifts, dims=dims) ++ rolled_seq[..., shifts:] = 0 ++ rolled_tensor[..., start_idx:end_idx] = rolled_seq ++ return rolled_tensor, rolled_tensor.sum() ++ ++ # CP enabled: each rank owns two chunks per sequence (front and mirrored tail). 
++ local_rank = torch.distributed.get_rank(group=cp_group) ++ global_ranks = torch.distributed.get_process_group_ranks(group=cp_group) ++ next_rank = global_ranks[(local_rank + 1) % cp_size] ++ prev_rank = global_ranks[(local_rank - 1) % cp_size] ++ ++ # iterate over each sequence individually ++ for i in range(len(cu_seqlens) - 1): ++ start_idx = cu_seqlens[i] ++ end_idx = cu_seqlens[i + 1] ++ ++ # the idx has been multiplied by cp_size, so we need to divide it by cp_size to get the local idx ++ local_start_idx = start_idx // cp_size ++ local_end_idx = end_idx // cp_size ++ tensor_slice = rolled_tensor[..., local_start_idx:local_end_idx].clone() ++ ++ # The following code is very similar as the code in roll_tensor function ++ local_chunks = tensor_slice.chunk(2, dim=dims) ++ rolled_chunks = [ ++ torch.roll(chunk, shifts=shifts, dims=dims) for chunk in local_chunks ++ ] ++ ++ tensor_send_list = [] ++ tensor_recv_list = [] ++ for chunk in rolled_chunks: ++ boundary = chunk.select(dims, shifts).contiguous().clone() ++ tensor_send_list.append(boundary) ++ tensor_recv_list.append(torch.empty_like(boundary)) ++ ++ ops = [] ++ if local_rank != 0: ++ ops.append(torch.distributed.isend(tensor=tensor_send_list[0], dst=prev_rank)) ++ ops.append(torch.distributed.irecv(tensor=tensor_recv_list[1], src=prev_rank)) ++ else: ++ tensor_recv_list[1].zero_() ++ ++ if local_rank != cp_size - 1: ++ ops.append(torch.distributed.irecv(tensor=tensor_recv_list[0], src=next_rank)) ++ ops.append(torch.distributed.isend(tensor=tensor_send_list[1], dst=next_rank)) ++ else: ++ tensor_recv_list[0].copy_(tensor_send_list[1]) ++ ++ for op in ops: ++ op.wait() ++ ++ index = [slice(None)] * rolled_chunks[0].dim() ++ index[dims] = shifts ++ for chunk, recv in zip(rolled_chunks, tensor_recv_list): ++ chunk[tuple(index)] = recv ++ ++ seq_result = torch.cat(rolled_chunks, dim=dims) ++ ++ # update the rolled tensor ++ rolled_tensor[..., local_start_idx:local_end_idx] = seq_result ++ ++ return 
rolled_tensor, rolled_tensor.sum() + + class MTPLossLoggingHelper: + """Helper class for logging MTP losses.""" +@@ -480,9 +588,10 @@ class MultiTokenPredictionLayer(MegatronModule): + def _get_embeddings( + self, + input_ids: torch.Tensor, +- position_ids: torch.Tensor, + embedding: Callable, + hidden_states: torch.Tensor, ++ position_ids: Optional[torch.Tensor] = None, ++ packed_seq_params: Optional[PackedSeqParams] = None, + ): + """ + Preprocesses input data for the Multi-Token Prediction (MTP) layers. +@@ -499,12 +608,23 @@ class MultiTokenPredictionLayer(MegatronModule): + sequence length, b is the batch size, and h is the hidden size. + """ + # Calc logits for the current Multi-Token Prediction (MTP) layers. +- input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1, cp_group=self.cp_group) +- position_ids, _ = roll_tensor(position_ids, shifts=-1, dims=-1, cp_group=self.cp_group) ++ input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) ++ ++ # Prepare/roll position ids only when applicable. ++ if position_ids is None: ++ # Fallback position ids for learned absolute embedding. 
++ seq_len = input_ids.size(-1) ++ position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device) ++ position_ids = position_ids.unsqueeze(0).expand_as(input_ids) ++ ++ position_ids, _ = roll_tensor( ++ position_ids, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params ++ ) + # embedding + decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) ++ decoder_input = decoder_input.detach() + +- hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) ++ hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) + + return input_ids, position_ids, decoder_input, hidden_states + +@@ -604,22 +724,66 @@ class MultiTokenPredictionLayer(MegatronModule): + return hidden_states + + def _checkpointed_forward(self, forward_func, *args, **kwargs): ++ """Wrap `forward_func` with activation checkpointing while only passing tensors. ++ ++ Non-tensor arguments (e.g., configuration objects, None) are captured via closure so ++ that checkpoint implementations never receive them directly, avoiding save_for_backward ++ issues with non-tensor inputs. ++ """ ++ ++ # TODO(jiajun): Is there any better implementation here? 
++ positional_specs = [] ++ kw_specs = [] ++ tensor_args: List[torch.Tensor] = [] ++ ++ for arg in args: ++ if torch.is_tensor(arg): ++ positional_specs.append(('tensor', len(tensor_args))) ++ tensor_args.append(arg) ++ else: ++ positional_specs.append(('const', arg)) ++ ++ for key, value in kwargs.items(): ++ if torch.is_tensor(value): ++ kw_specs.append((key, ('tensor', len(tensor_args)))) ++ tensor_args.append(value) ++ else: ++ kw_specs.append((key, ('const', value))) ++ ++ def run(*flat_tensor_args): ++ rebuilt_args = [] ++ for spec_type, payload in positional_specs: ++ if spec_type == 'tensor': ++ rebuilt_args.append(flat_tensor_args[payload]) ++ else: ++ rebuilt_args.append(payload) ++ ++ rebuilt_kwargs = {} ++ for key, (spec_type, payload) in kw_specs: ++ if spec_type == 'tensor': ++ rebuilt_kwargs[key] = flat_tensor_args[payload] ++ else: ++ rebuilt_kwargs[key] = payload ++ ++ return forward_func(*rebuilt_args, **rebuilt_kwargs) ++ ++ tensor_args_tuple = tuple(tensor_args) ++ + def checkpoint_handler(): +- """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" ++ """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`.""" + if self.config.fp8: + from megatron.core.extensions.transformer_engine import te_checkpoint + + return te_checkpoint( +- forward_func, ++ run, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), +- *args, +- **kwargs, ++ *tensor_args_tuple, + ) + else: + return tensor_parallel.checkpoint( +- forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() ++ run, self.config.distribute_saved_activations, *tensor_args_tuple + ) + + if self.config.recompute_method == 'uniform': +@@ -681,15 +845,13 @@ class MultiTokenPredictionLayer(MegatronModule): + [s, b, h], and optionally the updated context tensor if cross-attention is used. 
+ """ + assert context is None, f"multi token prediction + cross attention is not yet supported." +- assert ( +- packed_seq_params is None +- ), f"multi token prediction + sequence packing is not yet supported." + + input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings( + input_ids=input_ids, + position_ids=position_ids, + embedding=embedding, + hidden_states=hidden_states, ++ packed_seq_params=packed_seq_params, + ) + + if self.config.recompute_granularity == 'full' and self.training: +diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py +index d55bebe7e..1eecbbd38 100644 +--- a/megatron/core/transformer/transformer_config.py ++++ b/megatron/core/transformer/transformer_config.py +@@ -173,6 +173,10 @@ class TransformerConfig(ModelParallelConfig): + qk_layernorm: bool = False + """Whether to apply `normalization` type of normalization to the query and key embeddings.""" + ++ post_self_attn_layernorm: bool = False ++ post_mlp_layernorm: bool = False ++ use_gated_attention: bool = False ++ + test_mode: bool = False + """Whether to run real-time tests.""" + +diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py +index 84f22bdea..f0f3f8e86 100644 +--- a/megatron/core/transformer/transformer_layer.py ++++ b/megatron/core/transformer/transformer_layer.py +@@ -224,6 +224,7 @@ class TransformerLayerSubmodules: + input_layernorm: Union[ModuleSpec, type] = IdentityOp + self_attention: Union[ModuleSpec, type] = IdentityOp + self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp ++ post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + + pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp + cross_attention: Union[ModuleSpec, type] = IdentityOp +@@ -232,6 +233,7 @@ class TransformerLayerSubmodules: + pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + mlp: Union[ModuleSpec, type] = IdentityOp + mlp_bda: 
Union[ModuleSpec, type] = IdentityFuncOp ++ post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + + # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method + sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) +@@ -336,6 +338,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): + # [Module 3: BiasDropoutFusion] + self.self_attn_bda = build_module(submodules.self_attn_bda) + ++ self.post_self_attn_layernorm = build_module( ++ submodules.post_self_attn_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon, ++ ) ++ + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn + self.pre_cross_attn_layernorm = build_module( + submodules.pre_cross_attn_layernorm, +@@ -399,6 +408,13 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): + # [Module 9: BiasDropoutFusion] + self.mlp_bda = build_module(submodules.mlp_bda) + ++ self.post_mlp_layernorm = build_module( ++ submodules.post_mlp_layernorm, ++ config=self.config, ++ hidden_size=self.config.hidden_size, ++ eps=self.config.layernorm_epsilon ++ ) ++ + self.recompute_input_layernorm = False + self.recompute_pre_mlp_layernorm = False + self.recompute_mlp = False +@@ -535,6 +551,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): + attention_output_with_bias[0] + ) + ++ attention_output, attention_output_bias = attention_output_with_bias ++ attention_output = self.post_self_attn_layernorm(attention_output) ++ attention_output_with_bias = (attention_output, attention_output_bias) ++ + # TODO: could we move `bias_dropout_add_exec_handler` itself + # inside the module provided in the `bias_dropout_add_spec` module? 
+ nvtx_range_push(suffix="self_attn_bda") +@@ -635,6 +655,10 @@ class TransformerLayer(MegatronModule, BaseTransformerLayer): + else: + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) + ++ mlp_output, mlp_output_bias = mlp_output_with_bias ++ mlp_output = self.post_mlp_layernorm(mlp_output) ++ mlp_output_with_bias = (mlp_output, mlp_output_bias) ++ + if self.recompute_pre_mlp_layernorm: + # discard the output of the pre-mlp layernorm and register the recompute + # as a gradient hook of mlp_output_with_bias[0] +diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py +index e3459c5ee..7346bf35b 100644 +--- a/megatron/training/arguments.py ++++ b/megatron/training/arguments.py +@@ -937,8 +937,6 @@ def validate_args(args, defaults={}): + # MoE Spec check + if args.num_experts == 0: + args.num_experts = None +- if args.num_experts is not None: +- assert args.spec is None, "Model Spec must be None when using MoEs" + if args.num_experts is not None and args.moe_ffn_hidden_size is None: + args.moe_ffn_hidden_size = args.ffn_hidden_size + print("Warning: moe_ffn_hidden_size is not set, using ffn_hidden_size for MoE instead.") +@@ -1198,6 +1196,10 @@ def core_transformer_config_from_args(args, config_class=None): + if args.is_hybrid_model: + kw_args['is_hybrid_model'] = args.is_hybrid_model + ++ kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm ++ kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm ++ kw_args['use_gated_attention'] = args.use_gated_attention ++ + # handle quantization config + # NOTE: Kitchen arguments are only added to the namespace when + # Kitchen library is available. 
+@@ -1488,6 +1490,12 @@ def _add_network_size_args(parser): + action='store_true', + help='If set, use original BERT residula connection ' + 'ordering.') ++ group.add_argument('--post-self-attn-layernorm', action='store_true', ++ help='If set, use post self attention layernorm.') ++ group.add_argument('--post-mlp-layernorm', action='store_true', ++ help='If set, use post MLP layernorm.') ++ group.add_argument('--use-gated-attention', action='store_true', ++ help='If set, use gated attention as in Qwen3Next') + group.add_argument('--openai-gelu', action='store_true', + help='Use OpenAIs GeLU implementation. This option' + 'should not be used unless for backward compatibility' +diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py +index 5cf222ccc..d1554ca4c 100644 +--- a/megatron/training/tokenizer/tokenizer.py ++++ b/megatron/training/tokenizer/tokenizer.py +@@ -138,6 +138,8 @@ class _HuggingFaceTokenizer(MegatronTokenizer): + f"The transformers library must be installed to use huggingface_tokenizer_provider" + ) + ++ if "trust_remote_code" not in kwargs: ++ kwargs["trust_remote_code"] = True + # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there + self._tokenizer = transformers.AutoTokenizer.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs diff --git a/docker/patch/v0.5.6/sglang.patch b/docker/patch/v0.5.6/sglang.patch new file mode 100644 index 000000000..de12cdd43 --- /dev/null +++ b/docker/patch/v0.5.6/sglang.patch @@ -0,0 +1,2053 @@ +diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py +index ef52bda7f..537d892dc 100644 +--- a/python/sglang/srt/disaggregation/decode.py ++++ b/python/sglang/srt/disaggregation/decode.py +@@ -296,6 +296,13 @@ class DecodePreallocQueue: + ) + return kv_manager + ++ def release_memory_occupation(self): ++ if hasattr(self.kv_manager, "close"): 
++ self.kv_manager.close() ++ ++ def resume_memory_occupation(self): ++ self.kv_manager = self._init_kv_manager() ++ + def add(self, req: Req, is_retracted: bool = False) -> None: + """Add a request to the pending queue.""" + if self._check_if_req_exceed_kv_capacity(req): +diff --git a/python/sglang/srt/disaggregation/mooncake/conn.py b/python/sglang/srt/disaggregation/mooncake/conn.py +index d4414d084..c5fb10155 100644 +--- a/python/sglang/srt/disaggregation/mooncake/conn.py ++++ b/python/sglang/srt/disaggregation/mooncake/conn.py +@@ -1074,6 +1074,19 @@ class MooncakeKVManager(CommonKVManager): + f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), {len(affected_rooms)} requests affected" + ) + ++ def close(self): ++ # Batch deregister KV data buffers ++ if self.kv_args.kv_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.kv_data_ptrs) ++ ++ # Batch deregister auxiliary data buffers ++ if self.kv_args.aux_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.aux_data_ptrs) ++ ++ # Batch deregister state/extra pool data buffers ++ if self.kv_args.state_data_ptrs: ++ self.engine.batch_deregister(self.kv_args.state_data_ptrs) ++ + + class MooncakeKVSender(CommonKVSender): + +diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py +index 952374ed5..239ac2571 100644 +--- a/python/sglang/srt/disaggregation/prefill.py ++++ b/python/sglang/srt/disaggregation/prefill.py +@@ -305,6 +305,13 @@ class PrefillBootstrapQueue: + else: + return bootstrapped_reqs, failed_reqs + ++ def release_memory_occupation(self): ++ if hasattr(self.kv_manager, "close"): ++ self.kv_manager.close() ++ ++ def resume_memory_occupation(self): ++ self.kv_manager = self._init_kv_manager() ++ + + class SchedulerDisaggregationPrefillMixin: + """ +diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py +index 86c53f26b..52acf95b9 100644 
+--- a/python/sglang/srt/distributed/device_communicators/pynccl.py ++++ b/python/sglang/srt/distributed/device_communicators/pynccl.py +@@ -380,3 +380,9 @@ class PyNcclCommunicator: + + self.disabled = old_disable + self.stream = old_stream ++ ++ def nccl_pause(self): ++ self.nccl.ncclPause(self.comm) ++ ++ def nccl_resume(self): ++ self.nccl.ncclResume(self.comm) +diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py +index 6b12f2922..7028a4e46 100644 +--- a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py ++++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py +@@ -304,6 +304,17 @@ class NCCLLibrary: + Function("ncclGroupEnd", ncclResult_t, []), + ] + ++ if os.environ.get("AMEM_ENABLE", "0") == "1": ++ exported_functions.extend( ++ [ ++ # ncclResult_t ncclPause(ncclComm_t comm); ++ Function("ncclPause", ncclResult_t, [ncclComm_t]), ++ # ncclResult_t ncclResume(ncclComm_t comm); ++ Function("ncclResume", ncclResult_t, [ncclComm_t]), ++ Function("ncclSetGroupID", ncclResult_t, [ctypes.c_int]), ++ ] ++ ) ++ + exported_functions_symm_mem = [ + # ncclResult_t ncclCommWindowRegister(ncclComm_t comm, void* buff, size_t size, ncclWindow_t* win, int winFlags); + Function( +@@ -551,6 +562,12 @@ class NCCLLibrary: + def ncclGroupEnd(self) -> None: + self.NCCL_CHECK(self._funcs["ncclGroupEnd"]()) + ++ def ncclPause(self, comm: ncclComm_t) -> None: ++ self.NCCL_CHECK(self._funcs["ncclPause"](comm)) ++ ++ def ncclResume(self, comm: ncclComm_t) -> None: ++ self.NCCL_CHECK(self._funcs["ncclResume"](comm)) ++ + + __all__ = [ + "NCCLLibrary", +diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py +index cf90f6fe0..11d26df81 100644 +--- a/python/sglang/srt/distributed/parallel_state.py ++++ b/python/sglang/srt/distributed/parallel_state.py +@@ -1780,7 +1780,10 @@ def 
get_tensor_model_parallel_world_size(): + + def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" +- return get_tp_group().rank_in_group ++ try: ++ return get_tp_group().rank_in_group ++ except Exception: ++ return 0 + + + def get_pipeline_model_parallel_world_size(): +diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py +index 67a082ea6..390365864 100644 +--- a/python/sglang/srt/entrypoints/engine.py ++++ b/python/sglang/srt/entrypoints/engine.py +@@ -183,6 +183,7 @@ class Engine(EngineBase): + lora_path: Optional[List[Optional[str]]] = None, + custom_logit_processor: Optional[Union[List[str], str]] = None, + return_hidden_states: bool = False, ++ return_routed_experts: bool = False, + stream: bool = False, + bootstrap_host: Optional[Union[List[str], str]] = None, + bootstrap_port: Optional[Union[List[int], int]] = None, +@@ -218,6 +219,7 @@ class Engine(EngineBase): + lora_path=lora_path, + custom_logit_processor=custom_logit_processor, + return_hidden_states=return_hidden_states, ++ return_routed_experts=return_routed_experts, + stream=stream, + bootstrap_host=bootstrap_host, + bootstrap_port=bootstrap_port, +diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py +index 9f556a885..992843285 100644 +--- a/python/sglang/srt/layers/attention/vision.py ++++ b/python/sglang/srt/layers/attention/vision.py +@@ -518,11 +518,25 @@ class VisionAttention(nn.Module): + self.dummy_dim = (num_dummy_heads + num_heads) * self.head_size + + if self.qk_normalization: ++ norm_kwargs = ( ++ dict( ++ weight_dtype=torch.float32, ++ cast_x_before_out_mul=True, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) + self.q_norm = RMSNorm( +- self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim ++ self.dummy_dim, ++ eps=layer_norm_eps, ++ var_hidden_size=embed_dim, ++ **norm_kwargs, + ) + self.k_norm = RMSNorm( +- 
self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim ++ self.dummy_dim, ++ eps=layer_norm_eps, ++ var_hidden_size=embed_dim, ++ **norm_kwargs, + ) + + # Select attention backend via a unified method +@@ -648,6 +662,15 @@ class VisionAttention(nn.Module): + if x.dim() == 2: + x = x.unsqueeze(0) + assert x.dim() == 3, x.shape ++ if ( ++ get_global_server_args().rl_on_policy_target is not None ++ and position_embeddings is not None ++ ): ++ assert isinstance(position_embeddings, tuple), ( ++ "expected position_embeddings to be a tuple of two tensors,\n" ++ f"but got {type(position_embeddings)}, change if needed" ++ ) ++ position_embeddings = tuple(p.to(x.dtype) for p in position_embeddings) + x_shape = x.shape + bsz, s, _ = x_shape + head = self.num_attention_heads_per_partition +diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py +index 932f52aeb..79c6b664f 100644 +--- a/python/sglang/srt/layers/communicator.py ++++ b/python/sglang/srt/layers/communicator.py +@@ -372,6 +372,7 @@ class LayerCommunicator: + residual: torch.Tensor, + forward_batch: ForwardBatch, + quant_format: str = "", ++ post_residual_addition: Optional[torch.Tensor] = None, + ): + if get_attn_tp_context().input_scattered: + hidden_states, residual = self._tp_reduce_scatter( +@@ -453,7 +454,9 @@ class LayerCommunicator: + ) + else: + hidden_states, residual = self.input_layernorm( +- hidden_states, residual ++ hidden_states, ++ residual, ++ post_residual_addition, + ) + + hidden_states = self._communicate_simple_fn( +diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py +index 3293a8a59..a075b71ce 100644 +--- a/python/sglang/srt/layers/layernorm.py ++++ b/python/sglang/srt/layers/layernorm.py +@@ -84,15 +84,12 @@ class RMSNorm(CustomOp): + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + cast_x_before_out_mul: bool = False, +- fp32_residual: bool = False, +- weight_dtype: Optional = None, +- 
override_orig_dtype: Optional = None, ++ fp32_residual: bool = True, + ) -> None: + super().__init__() + self.cast_x_before_out_mul = cast_x_before_out_mul + self.fp32_residual = fp32_residual +- self.override_orig_dtype = override_orig_dtype +- self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype)) ++ self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.hidden_size = hidden_size + self.variance_size_override = ( +@@ -105,21 +102,26 @@ class RMSNorm(CustomOp): + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, ++ post_residual_addition: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if self.variance_size_override is not None: +- return self.forward_native(x, residual) ++ return self.forward_native(x, residual, post_residual_addition) + if is_batch_invariant_mode_enabled(): + if ( + residual is not None + or get_global_server_args().rl_on_policy_target == "fsdp" + ): +- return self.forward_native(x, residual) ++ return self.forward_native(x, residual, post_residual_addition) + return rms_norm_batch_invariant( + x, + self.weight.data, + self.variance_epsilon, + ) + if residual is not None: ++ # TODO: Ideally we want to have (a+b)+c. but right now we can only have a+(b+c). 
++ # (a+b)+c != a+(b+c), we probably need to add another parameter to fused_add_rmsnorm ++ if post_residual_addition is not None: ++ residual = residual + post_residual_addition + fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) + return x, residual + out = rmsnorm(x, self.weight.data, self.variance_epsilon) +@@ -179,17 +181,35 @@ class RMSNorm(CustomOp): + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, ++ post_residual_addition: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if not x.is_contiguous(): + x = x.contiguous() +- orig_dtype = self.override_orig_dtype or x.dtype ++ orig_dtype = x.dtype ++ ++ if residual is not None and not self.fp32_residual: ++ x = ( ++ x ++ + residual ++ + ( ++ post_residual_addition ++ if post_residual_addition is not None ++ else 0.0 ++ ) ++ ) ++ residual = x.clone() + x = x.to(torch.float32) +- if residual is not None: +- x = x + residual.to(torch.float32) +- if self.fp32_residual: +- residual = x.clone() +- else: +- residual = x.to(orig_dtype) ++ if residual is not None and self.fp32_residual: ++ x = ( ++ x ++ + residual.to(torch.float32) ++ + ( ++ post_residual_addition.to(torch.float32) ++ if post_residual_addition is not None ++ else 0.0 ++ ) ++ ) ++ residual = x.to(orig_dtype) + + hidden_size = x.shape[-1] + if hidden_size != self.hidden_size: +diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py +index 522865765..733bad5f2 100644 +--- a/python/sglang/srt/layers/logits_processor.py ++++ b/python/sglang/srt/layers/logits_processor.py +@@ -841,11 +841,6 @@ class LogitsProcessor(nn.Module): + None, # bias + True, # is_vnni + ) +- elif get_global_server_args().rl_on_policy_target is not None: +- # Due to tie-weight, we may not be able to change lm_head's weight dtype +- logits = torch.matmul( +- hidden_states.bfloat16(), lm_head.weight.T.bfloat16() +- ) + else: + logits = torch.matmul( + 
hidden_states.to(lm_head.weight.dtype), lm_head.weight.T +diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +index e7d5a67cc..639e47163 100644 +--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py ++++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +@@ -14,6 +14,7 @@ import torch.nn.functional as F + import triton.language as tl + + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig ++from sglang.srt.server_args import get_global_server_args + from sglang.srt.utils import ( + cpu_has_amx_support, + direct_register_custom_op, +@@ -626,7 +627,10 @@ def fused_experts_impl( + ).squeeze(dim=1) + else: + # According to micro benchmark results, torch.compile can get better performance for small token. +- if tokens_in_chunk <= 32: ++ if ( ++ not get_global_server_args().enable_deterministic_inference ++ and tokens_in_chunk <= 32 ++ ): + moe_sum_reduce_torch_compile( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], +diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py +new file mode 100644 +index 000000000..e16817f1f +--- /dev/null ++++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py +@@ -0,0 +1,279 @@ ++import logging ++from abc import ABC ++from contextlib import contextmanager ++from typing import Optional ++ ++import numpy as np ++import torch ++ ++from sglang.srt.configs.model_config import ModelConfig ++from sglang.srt.layers.dp_attention import ( ++ get_attention_dp_rank, ++ get_dp_local_info, ++ is_dp_attention_enabled, ++) ++from sglang.srt.mem_cache.memory_pool import ReqToTokenPool ++from sglang.srt.server_args import get_global_server_args ++ ++logger = logging.getLogger(__name__) ++ ++_GB = 1024 * 1024 * 1024 ++_MB = 1024 * 1024 ++ ++ ++def get_tensor_size_bytes(t: torch.Tensor): ++ return 
np.prod(t.shape) * t.dtype.itemsize ++ ++ ++class _RoutedExpertsDeviceCache: ++ def __init__( ++ self, ++ max_running_requests: int, ++ num_hidden_layers: int, ++ num_experts_per_tok: int, ++ num_fused_shared_experts: int, ++ device: str, ++ ) -> None: ++ self.buffer = torch.zeros( ++ ( ++ max( ++ get_global_server_args().chunked_prefill_size ++ * get_global_server_args().dp_size, ++ max_running_requests, ++ ), ++ num_hidden_layers, ++ num_experts_per_tok + num_fused_shared_experts, ++ ), ++ dtype=torch.int32, ++ device=device, ++ ) ++ self._finalize_allocation_log() ++ ++ def get_buffer_size_bytes(self): ++ assert hasattr(self, "buffer") ++ return get_tensor_size_bytes(self.buffer) ++ ++ def capture_fwd_routed_experts(self, layer_id: int, topk_ids: torch.Tensor): ++ assert layer_id is not None, "capturing routing experts but get layer_id None" ++ batch, _ = topk_ids.shape ++ self.buffer[:batch, layer_id, :] = topk_ids ++ ++ def _finalize_allocation_log(self): ++ """Common logging and memory usage computation for captured experts buffers.""" ++ buffer_size_MB = self.get_buffer_size_bytes() / _MB ++ logger.info( ++ f"Routing experts device buffer allocated. 
#shape: {tuple(self.buffer.shape)}, size: {buffer_size_MB:.2f} MB" ++ ) ++ ++ ++class _RoutedExpertsHostCache: ++ def __init__( ++ self, ++ num_tokens: int, ++ num_hidden_layers: int, ++ num_experts_per_tok: int, ++ ) -> None: ++ self.num_tokens = num_tokens ++ self.buffer = torch.zeros( ++ ( ++ num_tokens, ++ num_hidden_layers, ++ num_experts_per_tok, ++ ), ++ dtype=torch.int32, ++ device="cpu", ++ pin_memory=True, ++ ) ++ self._finalize_allocation_log() ++ ++ def get_buffer_size_bytes(self): ++ assert hasattr(self, "buffer") ++ return get_tensor_size_bytes(self.buffer) ++ ++ def set_experts_buffer(self, layer_id: int, loc: torch.Tensor, top_k: torch.Tensor): ++ self.buffer[layer_id, loc, :] = top_k.to(device="cpu", non_blocking=True) ++ ++ def _finalize_allocation_log(self): ++ """Common logging and memory usage computation for captured experts buffers.""" ++ buffer_size_GB = self.get_buffer_size_bytes() / _GB ++ logger.info( ++ f"Routing experts host buffer allocated. #tokens: {self.num_tokens}, size: {buffer_size_GB:.2f} GB" ++ ) ++ ++ ++class RoutedExpertsCapturer(ABC): ++ @staticmethod ++ def create( ++ enable: bool, ++ model_config: ModelConfig, ++ num_fused_shared_experts: int, ++ num_tokens: int, ++ max_running_requests: int, ++ device: str, ++ ): ++ if enable: ++ return _RoutedExpertsCapturerReal( ++ model_config, ++ num_tokens=num_tokens, ++ max_running_requests=max_running_requests, ++ num_fused_shared_experts=num_fused_shared_experts, ++ device=device, ++ ) ++ else: ++ return _RoutedExpertsCapturerNoop() ++ ++ def capture(self, layer_id: int, topk_ids: torch.Tensor): ++ raise NotImplementedError ++ ++ def get_routed_experts( ++ self, ++ req_pool_idx: int, ++ seqlen: int, ++ req_to_token_pool: ReqToTokenPool, ++ ): ++ raise NotImplementedError ++ ++ def sync_fwd_experts_buffer_DtoH( ++ self, ++ device_loc: torch.Tensor, ++ cpu_loc: torch.Tensor, ++ can_run_graph: bool, ++ cuda_graph_batch: int, ++ ): ++ raise NotImplementedError ++ ++ @contextmanager ++ 
def with_forward(self, forward_batch): ++ yield ++ ++ def get_host_cache(self): ++ raise NotImplementedError ++ ++ def get_device_cache(self): ++ raise NotImplementedError ++ ++ ++class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): ++ """Capturer for routed experts with host buffer""" ++ ++ def __init__( ++ self, ++ model_config: ModelConfig, ++ num_tokens: int, ++ max_running_requests: int, ++ num_fused_shared_experts: int, ++ device: str, ++ ): ++ self.forward_batch = None ++ self.num_fused_shared_experts = num_fused_shared_experts ++ self.num_hidden_layers = model_config.hf_text_config.num_hidden_layers ++ self.num_experts_per_tok = model_config.hf_text_config.num_experts_per_tok ++ ++ self.host_cache = _RoutedExpertsHostCache( ++ num_tokens=num_tokens, ++ num_hidden_layers=self.num_hidden_layers, ++ num_experts_per_tok=self.num_experts_per_tok, ++ ) ++ ++ self.device_cache = _RoutedExpertsDeviceCache( ++ max_running_requests=max_running_requests, ++ num_hidden_layers=self.num_hidden_layers, ++ num_experts_per_tok=self.num_experts_per_tok, ++ num_fused_shared_experts=self.num_fused_shared_experts, ++ device=device, ++ ) ++ ++ def capture(self, layer_id: int, topk_ids: torch.Tensor): ++ self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) ++ ++ def sync_fwd_experts_buffer_DtoH( ++ self, ++ device_loc: torch.Tensor, ++ cpu_loc: torch.Tensor, ++ can_run_graph: bool, ++ cuda_graph_batch: int, ++ ): ++ if is_dp_attention_enabled(): ++ local_start_pos, local_num_tokens = get_dp_local_info(self.forward_batch) ++ # handle with cuda graph padding ++ if can_run_graph: ++ local_start_pos = get_attention_dp_rank() * cuda_graph_batch ++ local_end_pos = local_start_pos + local_num_tokens ++ else: ++ local_end_pos = local_start_pos + local_num_tokens ++ else: ++ local_start_pos = 0 ++ local_end_pos = device_loc.shape[0] ++ ++ self.host_cache.buffer[cpu_loc] = self.device_cache.buffer[ ++ local_start_pos:local_end_pos, :, : self.num_experts_per_tok ++ ].cpu() 
++ ++ def get_routed_experts( ++ self, ++ req_pool_idx: int, ++ seqlen: int, ++ req_to_token_pool: ReqToTokenPool, ++ ): ++ cache_pool_idx = ( ++ req_to_token_pool.req_to_token[req_pool_idx][: seqlen - 1].cpu().clone() ++ ) ++ return self.get_host_cache().buffer[cache_pool_idx] ++ ++ @contextmanager ++ def with_forward(self, forward_batch): ++ self.forward_batch = forward_batch ++ yield ++ ++ def get_host_cache(self): ++ return self.host_cache ++ ++ def get_device_cache(self): ++ return self.device_cache ++ ++ ++class _RoutedExpertsCapturerNoop(RoutedExpertsCapturer): ++ def __init__(self): ++ pass ++ ++ def capture(self, layer_id: int, topk_ids: torch.Tensor): ++ pass ++ ++ def get_routed_experts( ++ self, ++ req_pool_idx: int, ++ seqlen: int, ++ req_to_token_pool: ReqToTokenPool, ++ ): ++ pass ++ ++ def sync_fwd_experts_buffer_DtoH( ++ self, ++ device_loc: torch.Tensor, ++ cpu_loc: torch.Tensor, ++ can_run_graph: bool, ++ cuda_graph_batch: int, ++ ): ++ pass ++ ++ @contextmanager ++ def with_forward(self, forward_batch): ++ yield ++ ++ def get_host_cache(self): ++ pass ++ ++ def get_device_cache(self): ++ pass ++ ++ ++_global_expert_capturer: Optional[RoutedExpertsCapturer] = _RoutedExpertsCapturerNoop() ++ ++ ++def get_global_experts_capturer(): ++ return _global_expert_capturer ++ ++ ++def set_global_experts_capturer(capturer: RoutedExpertsCapturer): ++ global _global_expert_capturer ++ _global_expert_capturer = capturer +\ No newline at end of file +diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py +index a802647e8..0fd550c0c 100644 +--- a/python/sglang/srt/layers/moe/topk.py ++++ b/python/sglang/srt/layers/moe/topk.py +@@ -48,6 +48,7 @@ from sglang.srt.eplb.expert_location_dispatch import ( + ) + from sglang.srt.layers.dp_attention import is_allocation_symmetric + from sglang.srt.layers.moe import get_moe_runner_backend ++from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer + from 
sglang.srt.utils import ( + cpu_has_amx_support, + get_bool_env_var, +@@ -212,6 +213,7 @@ class TopK(CustomOp): + self, + top_k: int, + *, ++ layer_id: Optional[int] = None, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, +@@ -233,6 +235,7 @@ class TopK(CustomOp): + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + ++ self.layer_id = layer_id + self.topk_config = TopKConfig( + top_k=top_k, + use_grouped_topk=use_grouped_topk, +@@ -260,6 +263,7 @@ class TopK(CustomOp): + self.topk_config.torch_native = True + return select_experts( + hidden_states=hidden_states, ++ layer_id=self.layer_id, + router_logits=router_logits, + topk_config=self.topk_config, + num_token_non_padded=num_token_non_padded, +@@ -309,6 +313,7 @@ class TopK(CustomOp): + ): + topk_output = select_experts( + hidden_states=hidden_states, ++ layer_id=self.layer_id, + router_logits=router_logits, + topk_config=self.topk_config, + num_token_non_padded=num_token_non_padded, +@@ -326,6 +331,7 @@ class TopK(CustomOp): + ) -> TopKOutput: + return select_experts( + hidden_states=hidden_states, ++ layer_id=self.layer_id, + router_logits=router_logits, + topk_config=self.topk_config, + num_token_non_padded=num_token_non_padded, +@@ -856,6 +862,7 @@ def select_experts( + router_logits: torch.Tensor, + topk_config: TopKConfig, + *, ++ layer_id: Optional[int] = None, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + ) -> StandardTopKOutput: +@@ -983,7 +990,10 @@ def select_experts( + ) + + get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids) +- ++ get_global_experts_capturer().capture( ++ layer_id=layer_id, ++ topk_ids=topk_ids, ++ ) + return StandardTopKOutput(topk_weights, topk_ids, router_logits) + + +diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py +index 
70466bb20..cd85fc2f2 100644 +--- a/python/sglang/srt/layers/moe/utils.py ++++ b/python/sglang/srt/layers/moe/utils.py +@@ -284,7 +284,7 @@ def speculative_moe_a2a_backend_context(): + global MOE_A2A_BACKEND + original_backend = MOE_A2A_BACKEND + try: +- MOE_A2A_BACKEND = MoeA2ABackend.NONE ++ MOE_A2A_BACKEND = get_speculative_moe_a2a_backend() + yield + finally: + MOE_A2A_BACKEND = original_backend +diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py +index 0cdb7e1ae..df8860409 100644 +--- a/python/sglang/srt/layers/rotary_embedding.py ++++ b/python/sglang/srt/layers/rotary_embedding.py +@@ -15,7 +15,6 @@ from sglang.srt.server_args import get_global_server_args + from sglang.srt.utils import ( + cpu_has_amx_support, + get_bool_env_var, +- get_compiler_backend, + is_cpu, + is_cuda, + is_hip, +@@ -132,9 +131,7 @@ class RotaryEmbedding(CustomOp): + + if get_global_server_args().rl_on_policy_target is not None: + self._forward_method = self.forward_native +- self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)( +- self._apply_rotary_emb_wrapped +- ) ++ + self.position_cos, self.position_sin = None, None + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: +@@ -1423,6 +1420,9 @@ class MRotaryEmbedding(RotaryEmbedding): + f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})" + ) + ++ if get_global_server_args().rl_on_policy_target is not None: ++ self._forward_method = self.forward_native ++ + def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None: + # __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`) + # is expensive, so avoid calling it if possible +@@ -1432,8 +1432,7 @@ class MRotaryEmbedding(RotaryEmbedding): + ): + self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) + +- @torch.compile(dynamic=True, backend=get_compiler_backend()) +- def _forward_native( ++ def forward_native( + self, + positions: 
torch.Tensor, + query: torch.Tensor, +@@ -1490,7 +1489,7 @@ class MRotaryEmbedding(RotaryEmbedding): + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + +- def forward( ++ def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, +@@ -1507,14 +1506,12 @@ class MRotaryEmbedding(RotaryEmbedding): + """ + assert positions.ndim == 1 or positions.ndim == 2 + +- if positions.ndim == 2 and self.mrope_section and _is_cuda: +- return self._forward_triton(positions, query, key) +- elif _is_npu: +- return self._forward_npu(positions, query, key) +- else: +- return self._forward_native(positions, query, key) ++ # Use Triton kernel for multimodal (2D positions) with mrope ++ if positions.ndim == 2 and self.mrope_section: ++ return self.forward_triton(positions, query, key) ++ return self.forward_native(positions, query, key, fused_set_kv_buffer_arg) + +- def _forward_triton( ++ def forward_triton( + self, + positions: torch.Tensor, + query: torch.Tensor, +@@ -1563,15 +1560,19 @@ class MRotaryEmbedding(RotaryEmbedding): + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + +- def _forward_npu( ++ def forward_npu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, ++ fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: ++ assert ( ++ fused_set_kv_buffer_arg is None ++ ), "fused_set_kv_buffer_arg is not supported for npu implementation" + # TODO: remove this when npu_mrope supports QNumHeads * QHeadSize > 4096 + if query.shape[1] > 4096: +- return self._forward_native(positions, query, key) ++ return self.forward_native(positions, query, key, fused_set_kv_buffer_arg) + rotary_mode = "half" + if self.is_neox_style: + rotary_mode = "half" +diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py +index 7f6f6a010..c4a673145 100644 +--- a/python/sglang/srt/layers/sampler.py ++++ 
b/python/sglang/srt/layers/sampler.py +@@ -105,16 +105,11 @@ class Sampler(nn.Module): + if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB: + probs_without_temp_scaling = torch.softmax(logits, dim=-1) + +- if get_global_server_args().rl_on_policy_target is not None: +- logits_div_temperature = ( +- logits.bfloat16().div(sampling_info.temperatures).bfloat16() +- ) +- logprobs_via_logsoftmax_kernel = torch.log_softmax( +- logits_div_temperature, dim=-1 +- ) +- + # Post process logits + logits.div_(sampling_info.temperatures) ++ if get_global_server_args().rl_on_policy_target is not None: ++ logprobs_via_logsoftmax_kernel = torch.log_softmax(logits, dim=-1) ++ + # For ascend backend, softmax is not needed before sampling + if not get_global_server_args().sampling_backend == "ascend" or ( + return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB +diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py +index 87922077e..8cb6bad8d 100644 +--- a/python/sglang/srt/managers/detokenizer_manager.py ++++ b/python/sglang/srt/managers/detokenizer_manager.py +@@ -247,6 +247,16 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): + s.sent_offset = len(output_str) + output_strs.append(incremental_output) + ++ output_routed_experts = [] ++ if recv_obj.output_routed_experts is not None: ++ output_routed_experts = [ ++ ( ++ output_routed_experts.tolist() ++ if output_routed_experts is not None ++ else [] ++ ) ++ for output_routed_experts in recv_obj.output_routed_experts ++ ] + return BatchStrOutput( + rids=recv_obj.rids, + http_worker_ipcs=recv_obj.http_worker_ipcs, +@@ -272,6 +282,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): + output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx, + output_token_entropy_val=recv_obj.output_token_entropy_val, + output_hidden_states=recv_obj.output_hidden_states, ++ output_routed_experts=output_routed_experts, + placeholder_tokens_idx=None, + 
placeholder_tokens_val=None, + retraction_counts=recv_obj.retraction_counts, +diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py +index e34736cc4..5e5997a1a 100644 +--- a/python/sglang/srt/managers/io_struct.py ++++ b/python/sglang/srt/managers/io_struct.py +@@ -23,6 +23,8 @@ from dataclasses import dataclass, field + from enum import Enum + from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union + ++import torch ++ + from sglang.srt.lora.lora_registry import LoRARef + from sglang.srt.managers.schedule_batch import BaseFinishReason + from sglang.srt.multimodal.mm_utils import has_valid_data +@@ -175,6 +177,8 @@ class GenerateReqInput(BaseReq): + log_metrics: bool = True + # Whether to return hidden states + return_hidden_states: Union[List[bool], bool] = False ++ # Whether to return captured routed experts ++ return_routed_experts: bool = False + + # The modalities of the image data [image, multi-images, video] + modalities: Optional[List[str]] = None +@@ -592,6 +596,7 @@ class GenerateReqInput(BaseReq): + if isinstance(self.return_hidden_states, list) + else self.return_hidden_states + ), ++ return_routed_experts=self.return_routed_experts, + modalities=self.modalities[i] if self.modalities else None, + session_params=self.session_params, + lora_path=self.lora_path[i] if self.lora_path is not None else None, +@@ -655,6 +660,9 @@ class TokenizedGenerateReqInput(BaseReq): + # Whether to return hidden states + return_hidden_states: bool = False + ++ # Whether to return captured routed experts ++ return_routed_experts: bool = False ++ + # The input embeds + input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None + +@@ -910,6 +918,9 @@ class BatchTokenIDOutput( + # Hidden states + output_hidden_states: List[List[float]] + ++ # The routed experts for each output token ++ output_routed_experts: List[torch.Tensor] ++ + # The information of placeholder tokens (e.g., image token) + # 
idx is the index of the token in the prompt after expansion. + # val is the length of padded tokens after expansion. +@@ -989,6 +1000,9 @@ class BatchStrOutput( + # Hidden states + output_hidden_states: List[List[float]] + ++ # The routed experts for each output token ++ output_routed_experts: List[List[int]] ++ + # The information of placeholder tokens (e.g., image token) + # idx is the index of the token in the prompt after expansion. + # val is the length of padded tokens after expansion. +diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py +index c4c5a9ebb..1450c5fd8 100644 +--- a/python/sglang/srt/managers/schedule_batch.py ++++ b/python/sglang/srt/managers/schedule_batch.py +@@ -450,6 +450,7 @@ class Req: + session_id: Optional[str] = None, + custom_logit_processor: Optional[str] = None, + return_hidden_states: bool = False, ++ return_routed_experts: bool = False, + eos_token_ids: Optional[Set[int]] = None, + bootstrap_host: Optional[str] = None, + bootstrap_port: Optional[int] = None, +@@ -629,6 +630,12 @@ class Req: + self.output_topk_p = None + self.output_topk_index = None + ++ # capture routed experts ++ self.return_routed_experts = return_routed_experts ++ self.routed_experts: Optional[torch.Tensor] = ( ++ None # cpu tensor: shape (seqlen, topk) ++ ) ++ + # Embedding (return values) + self.embedding = None + +@@ -992,6 +999,7 @@ class Req: + self.retraction_count += 1 + + self.prefix_indices = torch.empty((0,), dtype=torch.int64) ++ self.routed_experts = [] + self.last_node = None + self.swa_uuid_for_lock = None + self.extend_input_len = 0 +@@ -1159,6 +1167,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + # Whether to return hidden states + return_hidden_states: bool = False + ++ # Whether to return captured experts ++ return_routed_experts: bool = False ++ + # Whether this batch is prefill-only (no token generation needed) + is_prefill_only: bool = False + +@@ -1206,6 +1217,7 @@ 
class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + device=req_to_token_pool.device, + spec_algorithm=spec_algorithm, + return_hidden_states=any(req.return_hidden_states for req in reqs), ++ return_routed_experts=any(req.return_routed_experts for req in reqs), + is_prefill_only=all(req.is_prefill_only for req in reqs), + chunked_req=chunked_req, + dllm_config=dllm_config, +@@ -1457,6 +1469,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + self.req_pool_indices = req_pool_indices_tensor + self.orig_seq_lens = orig_seq_lens_tensor + self.out_cache_loc = out_cache_loc ++ self.out_cache_loc_cpu = out_cache_loc.cpu() + self.input_embeds = ( + torch.tensor(input_embeds).to(self.device, non_blocking=True) + if input_embeds +@@ -1508,10 +1521,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + + input_ids = torch.cat([self.input_ids, running_batch.input_ids]) + out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc]) ++ out_cache_loc_cpu = torch.cat( ++ [self.out_cache_loc_cpu, running_batch.out_cache_loc_cpu] ++ ) + + self.merge_batch(running_batch) + self.input_ids = input_ids + self.out_cache_loc = out_cache_loc ++ self.out_cache_loc_cpu = out_cache_loc_cpu + + # For overlap scheduler, the output_ids has one step delay + delta = 0 if self.enable_overlap else -1 +@@ -1677,6 +1694,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + self.seq_lens_cpu = torch.empty(0, dtype=torch.int64) + self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device) + self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device) ++ self.out_cache_loc_cpu = torch.empty(0, dtype=torch.int64, device="cpu") + self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device) + self.seq_lens_sum = 0 + self.extend_num_tokens = 0 +@@ -1736,6 +1754,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + + # Allocate memory + self.out_cache_loc = alloc_for_decode(self, 
token_per_req=1) ++ self.out_cache_loc_cpu = self.out_cache_loc.to("cpu", non_blocking=True) + + # Update req-level memory management fields + for req in self.reqs: +@@ -1807,6 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + self.seq_lens_cpu = self.seq_lens_cpu[keep_indices] + self.orig_seq_lens = self.orig_seq_lens[keep_indices_device] + self.out_cache_loc = None ++ self.out_cache_loc_cpu = None + self.seq_lens_sum = self.seq_lens.sum().item() + self.output_ids = self.output_ids[keep_indices_device] + self.return_logprob = any(req.return_logprob for req in self.reqs) +@@ -1852,6 +1872,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu]) + self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens]) + self.out_cache_loc = None ++ self.out_cache_loc_cpu = None + self.seq_lens_sum += other.seq_lens_sum + if self.output_ids is not None: + self.output_ids = torch.cat([self.output_ids, other.output_ids]) +@@ -1903,6 +1924,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + seq_lens=self.seq_lens, + orig_seq_lens=self.orig_seq_lens, + out_cache_loc=self.out_cache_loc, ++ out_cache_loc_cpu=self.out_cache_loc_cpu, + seq_lens_cpu=seq_lens_cpu, + seq_lens_sum=self.seq_lens_sum, + return_logprob=self.return_logprob, +@@ -1983,7 +2005,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): + def __str__(self): + return ( + f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, " +- f"#req={(len(self.reqs))})" ++ f"#req={(len(self.reqs))}), " ++ f"#out_cache_loc={self.out_cache_loc})" + ) + + +@@ -2038,6 +2061,9 @@ class ModelWorkerBatch: + # Sampling info + sampling_info: SamplingBatchInfo + ++ # cpu copy of out_cache_loc ++ out_cache_loc_cpu: Optional[torch.Tensor] = None ++ + # The original sequence lengths, Qwen-1M related + orig_seq_lens: Optional[torch.Tensor] = None + +diff --git 
a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py +index b801fd8f8..9e27cc825 100644 +--- a/python/sglang/srt/managers/scheduler.py ++++ b/python/sglang/srt/managers/scheduler.py +@@ -1305,6 +1305,7 @@ class Scheduler( + input_embeds=recv_req.input_embeds, + custom_logit_processor=recv_req.custom_logit_processor, + return_hidden_states=recv_req.return_hidden_states, ++ return_routed_experts=recv_req.return_routed_experts, + eos_token_ids=self.model_config.hf_eos_token_id, + bootstrap_host=recv_req.bootstrap_host, + bootstrap_port=recv_req.bootstrap_port, +diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py +index c48f5f893..a9796c25f 100644 +--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py ++++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py +@@ -9,6 +9,7 @@ import torch + from sglang.srt.disaggregation.utils import DisaggregationMode + from sglang.srt.environ import envs + from sglang.srt.layers.logits_processor import LogitsProcessorOutput ++from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer + from sglang.srt.managers.io_struct import ( + AbortReq, + BatchEmbeddingOutput, +@@ -112,6 +113,14 @@ class SchedulerOutputProcessorMixin: + req.check_finished() + + if req.finished(): ++ req.routed_experts = ( ++ get_global_experts_capturer().get_routed_experts( ++ req_pool_idx=req.req_pool_idx, ++ seqlen=req.seqlen, ++ req_to_token_pool=self.req_to_token_pool, ++ ) ++ ) ++ + release_kv_cache(req, self.tree_cache) + req.time_stats.completion_time = time.perf_counter() + elif not batch.decoding_reqs or req not in batch.decoding_reqs: +@@ -362,6 +371,12 @@ class SchedulerOutputProcessorMixin: + req.check_finished(new_accepted_len) + + if req.finished(): ++ req.routed_experts = get_global_experts_capturer().get_routed_experts( ++ req_pool_idx=req.req_pool_idx, ++ seqlen=req.seqlen, ++ 
req_to_token_pool=self.req_to_token_pool, ++ ) ++ + if self.server_args.disaggregation_decode_enable_offload_kvcache: + # Asynchronously offload KV cache; release_kv_cache will be called after Device->Host transfer completes + if not self.decode_offload_manager.offload_kv_cache(req): +@@ -756,6 +771,7 @@ class SchedulerOutputProcessorMixin: + spec_accepted_tokens = [] + retraction_counts = [] + output_hidden_states = None ++ output_routed_experts = None + + queue_times = [] + forward_entry_times = [] +@@ -946,6 +962,10 @@ class SchedulerOutputProcessorMixin: + if output_hidden_states is None: + output_hidden_states = [] + output_hidden_states.append(req.hidden_states) ++ if req.return_routed_experts: ++ if output_routed_experts is None: ++ output_routed_experts = [] ++ output_routed_experts.append(req.routed_experts) + + if ( + req.finished() +@@ -994,6 +1014,7 @@ class SchedulerOutputProcessorMixin: + output_token_ids_logprobs_idx=output_token_ids_logprobs_idx, + output_token_entropy_val=None, + output_hidden_states=output_hidden_states, ++ output_routed_experts=output_routed_experts, + placeholder_tokens_idx=None, + placeholder_tokens_val=None, + retraction_counts=retraction_counts, +diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py +index f8ebfc1f4..a05449fac 100644 +--- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py ++++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py +@@ -1,6 +1,7 @@ + from __future__ import annotations + + import logging ++import os + import traceback + from typing import TYPE_CHECKING, Tuple + +@@ -12,6 +13,9 @@ from sglang.srt.constants import ( + GPU_MEMORY_TYPE_KV_CACHE, + GPU_MEMORY_TYPE_WEIGHTS, + ) ++from sglang.srt.disaggregation.utils import DisaggregationMode ++from sglang.srt.distributed import get_moe_ep_group, get_moe_tp_group, get_tp_group ++from sglang.srt.layers.dp_attention import get_attention_tp_group + from 
sglang.srt.managers.io_struct import ( + CheckWeightsReqInput, + CheckWeightsReqOutput, +@@ -127,6 +131,13 @@ class SchedulerUpdateWeightsMixin: + self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE) + self.flush_cache() + ++ if self.disaggregation_mode == DisaggregationMode.DECODE: ++ if hasattr(self, "disagg_decode_prealloc_queue"): ++ self.disagg_decode_prealloc_queue.release_memory_occupation() ++ elif self.disaggregation_mode == DisaggregationMode.PREFILL: ++ if hasattr(self, "disagg_prefill_bootstrap_queue"): ++ self.disagg_prefill_bootstrap_queue.release_memory_occupation() ++ + if GPU_MEMORY_TYPE_WEIGHTS in tags: + self.stashed_model_static_state = _export_static_state( + self.tp_worker.model_runner.model +@@ -137,6 +148,20 @@ class SchedulerUpdateWeightsMixin: + if GPU_MEMORY_TYPE_CUDA_GRAPH in tags: + self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_CUDA_GRAPH) + ++ if os.environ.get("AMEM_ENABLE", "0") == "1": ++ tp_group = get_tp_group() ++ if tp_group is not None and tp_group.pynccl_comm is not None: ++ tp_group.pynccl_comm.nccl_pause() ++ attn_tp_group = get_attention_tp_group() ++ if attn_tp_group is not None and attn_tp_group.pynccl_comm is not None: ++ attn_tp_group.pynccl_comm.nccl_pause() ++ moe_ep_group = get_moe_ep_group() ++ if moe_ep_group is not None and moe_ep_group.pynccl_comm is not None: ++ moe_ep_group.pynccl_comm.nccl_pause() ++ moe_tp_group = get_moe_tp_group() ++ if moe_tp_group is not None and moe_tp_group.pynccl_comm is not None: ++ moe_tp_group.pynccl_comm.nccl_pause() ++ + torch.get_device_module().synchronize() + + return ReleaseMemoryOccupationReqOutput() +@@ -155,6 +180,20 @@ class SchedulerUpdateWeightsMixin: + if GPU_MEMORY_TYPE_CUDA_GRAPH in tags: + self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_CUDA_GRAPH) + ++ if os.environ.get("AMEM_ENABLE", "0") == "1": ++ tp_group = get_tp_group() ++ if tp_group is not None and tp_group.pynccl_comm is not None: ++ tp_group.pynccl_comm.nccl_resume() ++ attn_tp_group = 
get_attention_tp_group() ++ if attn_tp_group is not None and attn_tp_group.pynccl_comm is not None: ++ attn_tp_group.pynccl_comm.nccl_resume() ++ moe_ep_group = get_moe_ep_group() ++ if moe_ep_group is not None and moe_ep_group.pynccl_comm is not None: ++ moe_ep_group.pynccl_comm.nccl_resume() ++ moe_tp_group = get_moe_tp_group() ++ if moe_tp_group is not None and moe_tp_group.pynccl_comm is not None: ++ moe_tp_group.pynccl_comm.nccl_resume() ++ + if GPU_MEMORY_TYPE_WEIGHTS in tags: + self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS) + torch.distributed.barrier(self.tp_cpu_group) +@@ -167,6 +206,13 @@ class SchedulerUpdateWeightsMixin: + if GPU_MEMORY_TYPE_KV_CACHE in tags: + self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE) + ++ if self.disaggregation_mode == DisaggregationMode.DECODE: ++ if hasattr(self, "disagg_decode_prealloc_queue"): ++ self.disagg_decode_prealloc_queue.resume_memory_occupation() ++ elif self.disaggregation_mode == DisaggregationMode.PREFILL: ++ if hasattr(self, "disagg_prefill_bootstrap_queue"): ++ self.disagg_prefill_bootstrap_queue.resume_memory_occupation() ++ + return ResumeMemoryOccupationReqOutput() + + def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): +diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py +index b90cf0616..98d71d896 100644 +--- a/python/sglang/srt/managers/tokenizer_manager.py ++++ b/python/sglang/srt/managers/tokenizer_manager.py +@@ -888,6 +888,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): + session_params=session_params, + custom_logit_processor=obj.custom_logit_processor, + return_hidden_states=obj.return_hidden_states, ++ return_routed_experts=obj.return_routed_experts, + data_parallel_rank=obj.data_parallel_rank, + priority=obj.priority, + extra_key=obj.extra_key, +@@ -1621,6 +1622,9 @@ class TokenizerManager(TokenizerCommunicatorMixin): + if getattr(recv_obj, "output_hidden_states", None): + 
meta_info["hidden_states"] = recv_obj.output_hidden_states[i] + ++ if getattr(recv_obj, "output_routed_experts", None): ++ meta_info["routed_experts"] = recv_obj.output_routed_experts[i] ++ + if isinstance(recv_obj, BatchStrOutput): + state.text += recv_obj.output_strs[i] + if self.server_args.stream_output and state.obj.stream: +@@ -1747,12 +1751,13 @@ class TokenizerManager(TokenizerCommunicatorMixin): + return + + if len(recv_obj.input_token_logprobs_val) > 0: +- state.input_token_logprobs_val.extend( +- recv_obj.input_token_logprobs_val[recv_obj_index] +- ) +- state.input_token_logprobs_idx.extend( +- recv_obj.input_token_logprobs_idx[recv_obj_index] +- ) ++ if recv_obj.input_token_logprobs_val[recv_obj_index]: ++ state.input_token_logprobs_val.extend( ++ recv_obj.input_token_logprobs_val[recv_obj_index] ++ ) ++ state.input_token_logprobs_idx.extend( ++ recv_obj.input_token_logprobs_idx[recv_obj_index] ++ ) + state.output_token_logprobs_val.extend( + recv_obj.output_token_logprobs_val[recv_obj_index] + ) +diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py +index 3a85e6a7e..2859dafa1 100644 +--- a/python/sglang/srt/model_executor/forward_batch_info.py ++++ b/python/sglang/srt/model_executor/forward_batch_info.py +@@ -51,6 +51,7 @@ from sglang.srt.layers.dp_attention import ( + set_dp_buffer_len, + set_is_extend_in_batch, + ) ++from sglang.srt.server_args import get_global_server_args + from sglang.srt.utils import get_compiler_backend, is_npu, support_triton + from sglang.srt.utils.common import ceil_align + +@@ -214,6 +215,9 @@ class ForwardBatch: + # The sum of all sequence lengths + seq_lens_sum: int + ++ # cpu copy of out_cache_loc ++ out_cache_loc_cpu: Optional[torch.Tensor] = None ++ + # The original sequence length without being chunked. Qwen-1M related. 
+ orig_seq_lens: Optional[torch.Tensor] = None + +@@ -368,6 +372,7 @@ class ForwardBatch: + req_pool_indices=batch.req_pool_indices, + seq_lens=batch.seq_lens, + out_cache_loc=batch.out_cache_loc, ++ out_cache_loc_cpu=batch.out_cache_loc_cpu, + mm_inputs=batch.multimodal_inputs, + encoder_cached=batch.encoder_cached, + encoder_lens=batch.encoder_lens, +@@ -623,7 +628,10 @@ class ForwardBatch: + mm_input = batch.multimodal_inputs[batch_idx] + if self.forward_mode.is_decode(): + # 3 * N +- if mm_input is None: ++ if ( ++ mm_input is None ++ or get_global_server_args().rl_on_policy_target is not None ++ ): + mrope_positions_list[batch_idx] = torch.full( + (3, 1), + self.seq_lens[batch_idx] - 1, +@@ -640,7 +648,10 @@ class ForwardBatch: + batch.extend_seq_lens[batch_idx], + batch.extend_prefix_lens[batch_idx], + ) +- if mm_input is None: ++ if ( ++ mm_input is None ++ or get_global_server_args().rl_on_policy_target is not None ++ ): + # text only + mrope_positions = torch.tensor( + [ +@@ -823,6 +834,10 @@ class ForwardBatch: + ) + + self.out_cache_loc = self._pad_tensor_to_size(self.out_cache_loc, num_tokens) ++ if self.out_cache_loc_cpu is not None: ++ self.out_cache_loc_cpu = self._pad_tensor_to_size( ++ self.out_cache_loc_cpu, num_tokens ++ ) + if self.encoder_lens is not None: + self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs) + self.positions = self._pad_tensor_to_size(self.positions, num_tokens) +diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py +index 4d58278b7..8f50dc430 100644 +--- a/python/sglang/srt/model_executor/model_runner.py ++++ b/python/sglang/srt/model_executor/model_runner.py +@@ -94,6 +94,11 @@ from sglang.srt.layers.dp_attention import ( + set_is_extend_in_batch, + ) + from sglang.srt.layers.logits_processor import LogitsProcessorOutput ++from sglang.srt.layers.moe.routed_experts_capturer import ( ++ RoutedExpertsCapturer, ++ get_global_experts_capturer, ++ 
set_global_experts_capturer, ++) + from sglang.srt.layers.pooler import EmbeddingPoolerOutput + from sglang.srt.layers.sampler import Sampler + from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model +@@ -502,6 +507,10 @@ class ModelRunner: + server_args.max_running_requests, + server_args.max_total_tokens, + ) ++ ++ # Init routed experts capturer ++ self.init_routed_experts_capturer() ++ + if self.device == "cuda": + self.init_cublas() + self.init_attention_backend() +@@ -545,6 +554,40 @@ class ModelRunner: + # Initialize piecewise CUDA graph + self.init_piecewise_cuda_graphs() + ++ def init_routed_experts_capturer(self): ++ # TODO: the redundant logic with TpModelWorker ++ max_running_requests = min( ++ ( ++ self.max_total_num_tokens // 2 ++ if self.server_args.max_running_requests is None ++ else self.server_args.max_running_requests ++ // ( ++ self.server_args.dp_size ++ if self.server_args.enable_dp_attention ++ else 1 ++ ) ++ ), ++ self.req_to_token_pool.size, ++ ) ++ ++ if not self.server_args.disable_shared_experts_fusion and hasattr( ++ self.model, "num_fused_shared_experts" ++ ): ++ num_fused_shared_experts = self.model.num_fused_shared_experts ++ else: ++ num_fused_shared_experts = 0 ++ ++ set_global_experts_capturer( ++ RoutedExpertsCapturer.create( ++ enable=get_global_server_args().enable_return_routed_experts, ++ model_config=self.model_config, ++ num_fused_shared_experts=num_fused_shared_experts, ++ num_tokens=self.max_total_num_tokens + self.page_size, ++ max_running_requests=max_running_requests, ++ device=self.device, ++ ) ++ ) ++ + def model_specific_adjustment(self): + server_args = self.server_args + +@@ -792,7 +835,11 @@ class ModelRunner: + ) + with self.memory_saver_adapter.region( + GPU_MEMORY_TYPE_WEIGHTS, +- enable_cpu_backup=enable_cpu_backup, ++ enable_cpu_backup=( ++ self.server_args.enable_weights_cpu_backup ++ if not self.is_draft_worker ++ else True ++ ), + ): + self.model = get_model( + 
model_config=self.model_config, +@@ -2645,9 +2692,12 @@ class ModelRunner: + ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]: + self.forward_pass_id += 1 + +- with get_global_expert_distribution_recorder().with_forward_pass( +- self.forward_pass_id, +- forward_batch, ++ with ( ++ get_global_expert_distribution_recorder().with_forward_pass( ++ self.forward_pass_id, ++ forward_batch, ++ ), ++ get_global_experts_capturer().with_forward(forward_batch), + ): + output = self._forward_raw( + forward_batch, +@@ -2656,6 +2706,13 @@ class ModelRunner: + reinit_attn_backend, + split_forward_count, + ) ++ # Copy cached routing experts' buffers back to CPU cache ++ get_global_experts_capturer().sync_fwd_experts_buffer_DtoH( ++ device_loc=forward_batch.out_cache_loc, ++ cpu_loc=forward_batch.out_cache_loc_cpu, ++ can_run_graph=output[1], ++ cuda_graph_batch=getattr(self.graph_runner, "bs", None), ++ ) + + if self.eplb_manager is not None: + self.eplb_manager.on_forward_pass_end() +diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py +index dc30b4f0a..f29dc4b71 100644 +--- a/python/sglang/srt/models/deepseek_v2.py ++++ b/python/sglang/srt/models/deepseek_v2.py +@@ -667,6 +667,7 @@ class DeepseekV2MoE(nn.Module): + + self.topk = TopK( + top_k=config.num_experts_per_tok + self.num_fused_shared_experts, ++ layer_id=self.layer_id, + renormalize=config.norm_topk_prob, + use_grouped_topk=True, + num_expert_group=config.n_group, +diff --git a/python/sglang/srt/models/ernie4.py b/python/sglang/srt/models/ernie4.py +index ab1b6576b..dffd8f09a 100644 +--- a/python/sglang/srt/models/ernie4.py ++++ b/python/sglang/srt/models/ernie4.py +@@ -87,6 +87,7 @@ class Ernie4Moe(nn.Module): + + self.topk = TopK( + top_k=config.moe_k, ++ layer_id=layer_id, + renormalize=True, + use_grouped_topk=False, + correction_bias=self.gate.e_score_correction_bias, +diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py 
+index a9689b8f2..bc8538da8 100644 +--- a/python/sglang/srt/models/glm4_moe.py ++++ b/python/sglang/srt/models/glm4_moe.py +@@ -379,6 +379,17 @@ class Glm4MoeSparseMoeBlock(nn.Module): + + self.gate = Glm4MoeGate(config=config, prefix=add_prefix("gate", prefix)) + ++ self.topk = TopK( ++ top_k=self.top_k, ++ layer_id=self.layer_id, ++ renormalize=config.norm_topk_prob, ++ use_grouped_topk=True, ++ num_expert_group=config.n_group, ++ topk_group=config.topk_group, ++ correction_bias=self.gate.e_score_correction_bias, ++ routed_scaling_factor=self.routed_scaling_factor, ++ ) ++ + self.experts = get_moe_impl_class(quant_config)( + num_experts=config.n_routed_experts + self.num_fused_shared_experts, + num_fused_shared_experts=self.num_fused_shared_experts, +diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py +index 9474700c4..398d622ff 100644 +--- a/python/sglang/srt/models/gpt_oss.py ++++ b/python/sglang/srt/models/gpt_oss.py +@@ -113,6 +113,7 @@ class GptOssSparseMoeBlock(nn.Module): + self.topk = TopK( + top_k=config.num_experts_per_tok, + renormalize=True, ++ layer_id=layer_id, + ) + + self.top_k = config.num_experts_per_tok +diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py +index fd513060a..a089475b7 100644 +--- a/python/sglang/srt/models/grok.py ++++ b/python/sglang/srt/models/grok.py +@@ -142,6 +142,7 @@ class Grok1MoE(nn.Module): + self.topk = TopK( + top_k=top_k, + renormalize=False, ++ layer_id=layer_id, + custom_routing_function=custom_routing_function, + ) + +diff --git a/python/sglang/srt/models/hunyuan.py b/python/sglang/srt/models/hunyuan.py +index 7c6fd9e48..b20d28544 100644 +--- a/python/sglang/srt/models/hunyuan.py ++++ b/python/sglang/srt/models/hunyuan.py +@@ -150,6 +150,7 @@ class HunYuanSparseMoeBlock(nn.Module): + + self.topk = TopK( + top_k=top_k, ++ layer_id=layer_id, + renormalize=True if top_k > 1 else False, + ) + +diff --git a/python/sglang/srt/models/longcat_flash.py 
b/python/sglang/srt/models/longcat_flash.py +index 3530609ba..01c89e893 100644 +--- a/python/sglang/srt/models/longcat_flash.py ++++ b/python/sglang/srt/models/longcat_flash.py +@@ -245,6 +245,7 @@ class LongcatFlashMoE(nn.Module): + renormalize=False, + use_grouped_topk=False, + correction_bias=self.router.e_score_correction_bias.data, ++ layer_id=layer_id, + ) + self.topk.forward = self.topk.forward_native + +diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py +index a7dbadec6..c83a41338 100644 +--- a/python/sglang/srt/models/qwen2.py ++++ b/python/sglang/srt/models/qwen2.py +@@ -90,9 +90,6 @@ class Qwen2MLP(nn.Module): + self.act_fn = SiluAndMul() + + def forward(self, x): +- if get_global_server_args().rl_on_policy_target is not None: +- x = x.bfloat16() +- + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) +@@ -279,11 +276,6 @@ class Qwen2Model(nn.Module): + quant_config=quant_config, + enable_tp=not is_dp_attention_enabled(), + prefix=add_prefix("embed_tokens", prefix), +- params_dtype=( +- torch.float32 +- if get_global_server_args().rl_on_policy_target is not None +- else None +- ), + ) + else: + self.embed_tokens = PPMissingLayer() +@@ -306,10 +298,8 @@ class Qwen2Model(nn.Module): + if self.pp_group.is_last_rank: + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, +- override_orig_dtype=torch.float32, +- fp32_residual=True, ++ fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} +diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py +index ea33e81ef..561934dce 100644 +--- a/python/sglang/srt/models/qwen2_moe.py ++++ b/python/sglang/srt/models/qwen2_moe.py +@@ -161,6 +161,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): + self.topk = TopK( + top_k=config.num_experts_per_tok, + renormalize=config.norm_topk_prob, ++ layer_id=layer_id, + ) + + self.experts = 
get_moe_impl_class(quant_config)( +@@ -581,7 +582,17 @@ class Qwen2MoeModel(nn.Module): + prefix=add_prefix("layers", prefix), + ) + if self.pp_group.is_last_rank: +- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.norm = RMSNorm( ++ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs ++ ) + else: + self.norm = PPMissingLayer(return_tuple=True) + +diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py +index 30b92acbd..a0d14895f 100644 +--- a/python/sglang/srt/models/qwen3.py ++++ b/python/sglang/srt/models/qwen3.py +@@ -90,8 +90,8 @@ class Qwen3Attention(nn.Module): + + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, ++ fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} +@@ -256,10 +256,8 @@ class Qwen3DecoderLayer(nn.Module): + + norm_kwargs = ( + dict( +- weight_dtype=torch.float32, + cast_x_before_out_mul=True, +- override_orig_dtype=torch.float32, +- fp32_residual=True, ++ fp32_residual=False, + ) + if get_global_server_args().rl_on_policy_target is not None + else {} +@@ -289,10 +287,14 @@ class Qwen3DecoderLayer(nn.Module): + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], ++ post_residual_addition: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + hidden_states, residual = self.layer_communicator.prepare_attn( +- hidden_states, residual, forward_batch ++ hidden_states, ++ residual, ++ forward_batch, ++ post_residual_addition=post_residual_addition, + ) + if hidden_states.shape[0] != 0: + hidden_states = self.self_attn( +diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py +index 9737ac719..09c756918 100644 +--- 
a/python/sglang/srt/models/qwen3_moe.py ++++ b/python/sglang/srt/models/qwen3_moe.py +@@ -22,6 +22,7 @@ import math + from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar + + import torch ++import torch.nn.functional as F + from torch import nn + from transformers import PretrainedConfig + +@@ -50,7 +51,7 @@ from sglang.srt.layers.moe import ( + ) + from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class + from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +-from sglang.srt.layers.moe.topk import TopK ++from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK + from sglang.srt.layers.moe.utils import RoutingMethodType + from sglang.srt.layers.quantization.base_config import QuantizationConfig + from sglang.srt.layers.radix_attention import RadixAttention +@@ -227,7 +228,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module): + top_k=config.num_experts_per_tok, + renormalize=config.norm_topk_prob, + use_grouped_topk=False, ++ layer_id=layer_id, + ) ++ self.top_k = config.num_experts_per_tok + + self.experts = get_moe_impl_class(quant_config)( + num_experts=config.num_experts +@@ -293,7 +296,22 @@ class Qwen3MoeSparseMoeBlock(nn.Module): + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) +- topk_output = self.topk(hidden_states, router_logits) ++ ++ if get_global_server_args().rl_on_policy_target is not None: ++ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) ++ routing_weights, selected_experts = torch.topk( ++ routing_weights, self.top_k, dim=-1 ++ ) ++ routing_weights /= routing_weights.sum(dim=-1, keepdim=True) ++ routing_weights = routing_weights.to(hidden_states.dtype) ++ topk_output = StandardTopKOutput( ++ topk_weights=routing_weights, ++ topk_ids=selected_experts, ++ router_logits=router_logits, ++ ) ++ else: ++ topk_output = self.topk(hidden_states, router_logits) ++ + final_hidden_states = self.experts(hidden_states, topk_output) + if ( + self.tp_size > 1 
+@@ -474,13 +492,14 @@ class Qwen3MoeAttention(nn.Module): + ) + self.compatible_with_fused_kv_buffer = ( + False if isinstance(self.rotary_emb, MRotaryEmbedding) else True +- ) ++ ) and (get_global_server_args().rl_on_policy_target is None) + self.compatible_with_fused_qk_norm_rope = ( + not isinstance(self.rotary_emb, MRotaryEmbedding) + ) and self.head_dim in (64, 128, 256) + self.use_fused_qk_norm_rope = ( + get_global_server_args().enable_fused_qk_norm_rope + and self.compatible_with_fused_qk_norm_rope ++ and (get_global_server_args().rl_on_policy_target is None) + ) + self._used_fused_qk_norm_rope_last_call = False + +@@ -493,8 +512,16 @@ class Qwen3MoeAttention(nn.Module): + prefix=add_prefix("attn", prefix), + ) + +- self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) +- self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) ++ self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs) + self.alt_stream = alt_stream + + def _apply_qk_norm( +@@ -751,9 +778,19 @@ class Qwen3MoeDecoderLayer(nn.Module): + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) +- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ norm_kwargs = ( ++ dict( ++ cast_x_before_out_mul=True, ++ fp32_residual=False, ++ ) ++ if get_global_server_args().rl_on_policy_target is not None ++ else {} ++ ) ++ self.input_layernorm = RMSNorm( ++ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs ++ ) + self.post_attention_layernorm = RMSNorm( +- config.hidden_size, eps=config.rms_norm_eps ++ config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs + ) + + self.layer_communicator = LayerCommunicator( +diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py +index 
ed52f7ff4..8ce9fab9d 100644 +--- a/python/sglang/srt/models/qwen3_vl.py ++++ b/python/sglang/srt/models/qwen3_vl.py +@@ -18,7 +18,6 @@ import re + from functools import lru_cache, partial + from typing import Callable, Iterable, List, Optional, Tuple, Union + +-import numpy as np + import torch + import torch.nn as nn + from einops import rearrange +@@ -349,83 +348,65 @@ class Qwen3VLMoeVisionModel(nn.Module, RotaryPosMixin): + return rotary_pos_emb + + def fast_pos_embed_interpolate(self, grid_thw): ++ grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + num_grid_per_side = int(self.num_position_embeddings**0.5) ++ device = self.pos_embed.weight.device + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + +- # TODO: use torch instand of np +- for t, h, w in grid_thw: +- h_idxs = np.linspace(0, num_grid_per_side - 1, h) +- w_idxs = np.linspace(0, num_grid_per_side - 1, w) ++ for t, h, w in zip(grid_ts, grid_hs, grid_ws): ++ h_idxs = torch.linspace(0, num_grid_per_side - 1, h) ++ w_idxs = torch.linspace(0, num_grid_per_side - 1, w) + +- h_idxs_floor = h_idxs.astype(int) +- w_idxs_floor = w_idxs.astype(int) +- h_idxs_ceil = (h_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1) +- w_idxs_ceil = (w_idxs.astype(int) + 1).clip(max=num_grid_per_side - 1) ++ h_idxs_floor = h_idxs.int() ++ w_idxs_floor = w_idxs.int() ++ h_idxs_ceil = (h_idxs.int() + 1).clip(max=num_grid_per_side - 1) ++ w_idxs_ceil = (w_idxs.int() + 1).clip(max=num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + +- idx_list[0].extend( +- ((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_floor[None]) +- .flatten() +- .tolist() +- * t +- ) +- idx_list[1].extend( +- ((h_idxs_floor * num_grid_per_side)[None].T + w_idxs_ceil[None]) +- .flatten() +- .tolist() +- * t +- ) +- idx_list[2].extend( +- ((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_floor[None]) +- .flatten() +- .tolist() +- * t +- ) +- idx_list[3].extend( +- 
((h_idxs_ceil * num_grid_per_side)[None].T + w_idxs_ceil[None]) +- .flatten() +- .tolist() +- * t +- ) ++ base_h = h_idxs_floor * num_grid_per_side ++ base_h_ceil = h_idxs_ceil * num_grid_per_side + +- weight_list[0].extend( +- ((1 - dh)[None].T * (1 - dw)[None]).flatten().tolist() * t +- ) +- weight_list[1].extend(((1 - dh)[None].T * dw[None]).flatten().tolist() * t) +- weight_list[2].extend((dh[None].T * (1 - dw)[None]).flatten().tolist() * t) +- weight_list[3].extend((dh[None].T * dw[None]).flatten().tolist() * t) ++ indices = [ ++ (base_h[None].T + w_idxs_floor[None]).flatten(), ++ (base_h[None].T + w_idxs_ceil[None]).flatten(), ++ (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), ++ (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), ++ ] + +- device = self.pos_embed.weight.device +- dtype = self.pos_embed.weight.dtype ++ weights = [ ++ ((1 - dh)[None].T * (1 - dw)[None]).flatten(), ++ ((1 - dh)[None].T * dw[None]).flatten(), ++ (dh[None].T * (1 - dw)[None]).flatten(), ++ (dh[None].T * dw[None]).flatten(), ++ ] + +- p0 = ( +- self.pos_embed(torch.tensor(idx_list[0], dtype=torch.long, device=device)) +- * torch.tensor(weight_list[0], dtype=dtype, device=device)[:, None] +- ) +- p1 = ( +- self.pos_embed(torch.tensor(idx_list[1], dtype=torch.long, device=device)) +- * torch.tensor(weight_list[1], dtype=dtype, device=device)[:, None] +- ) +- p2 = ( +- self.pos_embed(torch.tensor(idx_list[2], dtype=torch.long, device=device)) +- * torch.tensor(weight_list[2], dtype=dtype, device=device)[:, None] ++ for i in range(4): ++ idx_list[i].extend(indices[i].tolist()) ++ weight_list[i].extend(weights[i].tolist()) ++ ++ idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) ++ weight_tensor = torch.tensor( ++ weight_list, dtype=self.pos_embed.weight.dtype, device=device + ) +- p3 = ( +- self.pos_embed(torch.tensor(idx_list[3], dtype=torch.long, device=device)) +- * torch.tensor(weight_list[3], dtype=dtype, device=device)[:, None] ++ pos_embeds = 
self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] ++ patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] ++ ++ patch_pos_embeds = patch_pos_embeds.split( ++ [h * w for h, w in zip(grid_hs, grid_ws)] + ) + +- patch_pos_embeds = p0 + p1 + p2 + p3 +- patch_pos_embeds = patch_pos_embeds.split([t * h * w for t, h, w in grid_thw]) + patch_pos_embeds_permute = [] +- m_size = self.spatial_merge_size +- for pos_embed, (t, h, w) in zip(patch_pos_embeds, grid_thw): ++ merge_size = self.spatial_merge_size ++ for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): ++ pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( +- pos_embed.view(t, h // m_size, m_size, w // m_size, m_size, -1) ++ pos_embed.view( ++ t, h // merge_size, merge_size, w // merge_size, merge_size, -1 ++ ) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) +@@ -555,21 +536,27 @@ class Qwen3LLMModel(Qwen3Model): + hidden_states + residual if residual is not None else hidden_states + ) + ++ deepstack_embeds = None ++ if input_deepstack_embeds is not None: ++ prev_layer_idx = layer_idx - 1 ++ if prev_layer_idx in self.deepstack_embed_to_decoder_layer: ++ sep = self.hidden_size * prev_layer_idx ++ deepstack_embeds = input_deepstack_embeds[ ++ :, sep : sep + self.hidden_size ++ ] ++ ++ # SGLang applies residual at the START of the next layer, not at the END like HuggingFace. ++ # See: https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L549 ++ # To match HF behavior, deepstack must be added AFTER residual: (hidden_states + residual) + deepstack ++ # The order matters because addition with different tensors is not associative in practice. 
+ hidden_states, residual = layer( + positions, + hidden_states, + forward_batch, + residual, ++ post_residual_addition=deepstack_embeds, + ) + +- # process deepstack +- if ( +- input_deepstack_embeds is not None +- and layer_idx in self.deepstack_embed_to_decoder_layer +- ): +- sep = self.hidden_size * layer_idx +- hidden_states += input_deepstack_embeds[:, sep : sep + self.hidden_size] +- + if not self.pp_group.is_last_rank: + return PPProxyTensors( + { +diff --git a/python/sglang/srt/models/step3_vl.py b/python/sglang/srt/models/step3_vl.py +index 4474f62d5..0e537c398 100644 +--- a/python/sglang/srt/models/step3_vl.py ++++ b/python/sglang/srt/models/step3_vl.py +@@ -129,6 +129,7 @@ class Step3TextMoEMLP(nn.Module): + top_k=config.moe_top_k, + renormalize=config.norm_expert_weight, + use_grouped_topk=False, ++ layer_id=layer_id, + ) + + self.experts = get_moe_impl_class(quant_config)( +diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py +index 370aec2b6..47666d8f3 100644 +--- a/python/sglang/srt/multimodal/processors/base_processor.py ++++ b/python/sglang/srt/multimodal/processors/base_processor.py +@@ -13,6 +13,7 @@ from PIL import Image + from transformers import BaseImageProcessorFast + + from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem ++from sglang.srt.server_args import get_global_server_args + from sglang.srt.utils import ( + get_bool_env_var, + is_npu, +@@ -260,7 +261,9 @@ class BaseMultimodalProcessor(ABC): + and isinstance(processor.image_processor, BaseImageProcessorFast) + and not self.server_args.disable_fast_image_processor + ): +- if not _is_npu: ++ if get_global_server_args().rl_on_policy_target is not None: ++ kwargs["device"] = "cpu" ++ elif not _is_npu: + kwargs["device"] = "cuda" + elif processor.__class__.__name__ not in { + "Qwen2_5_VLProcessor", +diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py +index 
8e7753dab..323788f39 100644 +--- a/python/sglang/srt/server_args.py ++++ b/python/sglang/srt/server_args.py +@@ -535,6 +535,7 @@ class ServerArgs: + disable_fast_image_processor: bool = False + keep_mm_feature_on_device: bool = False + enable_return_hidden_states: bool = False ++ enable_return_routed_experts: bool = False + scheduler_recv_interval: int = 1 + numa_node: Optional[List[int]] = None + enable_deterministic_inference: bool = False +@@ -1966,6 +1967,9 @@ class ServerArgs: + "Enable deterministic inference because of rl_on_policy_target." + ) + self.enable_deterministic_inference = True ++ ++ # For VLM ++ os.environ["SGLANG_VLM_CACHE_SIZE_MB"] = "0" + # TODO remove this environment variable as a whole + os.environ["SGLANG_ENABLE_DETERMINISTIC_INFERENCE"] = "1" + +@@ -3705,6 +3709,11 @@ class ServerArgs: + action="store_true", + help="Enable returning hidden states with responses.", + ) ++ parser.add_argument( ++ "--enable-return-routed-experts", ++ action="store_true", ++ help="Enable returning routed experts of each layer with responses.", ++ ) + parser.add_argument( + "--scheduler-recv-interval", + type=int, +diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py +index b3d72df05..ddfe0b178 100644 +--- a/python/sglang/srt/speculative/eagle_info.py ++++ b/python/sglang/srt/speculative/eagle_info.py +@@ -746,6 +746,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): + self.topk_index = self.topk_index[: len(new_indices)] + self.hidden_states = self.hidden_states[: len(new_indices)] + self.verified_id = self.verified_id[: len(new_indices)] ++ if self.accept_length is not None: ++ self.accept_length = self.accept_length[: len(new_indices)] ++ if self.accept_length_cpu is not None: ++ self.accept_length_cpu = self.accept_length_cpu[: len(new_indices)] + else: + # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index` + self.topk_p = self.topk_p[new_indices] +@@ -777,6 
+781,27 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): + self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) + self.topk_p = torch.cat([self.topk_p, spec_info.topk_p]) + self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) ++ if self.accept_length is not None and spec_info.accept_length is not None: ++ self.accept_length = torch.cat( ++ [self.accept_length, spec_info.accept_length] ++ ) ++ self.accept_length_cpu = self.accept_length.tolist() ++ elif self.accept_length is not None: ++ zeros = torch.zeros( ++ [spec_info.verified_id.shape[0]], ++ dtype=self.accept_length.dtype, ++ device=self.accept_length.device, ++ ) ++ self.accept_length = torch.cat([self.accept_length, zeros]) ++ self.accept_length_cpu = self.accept_length.tolist() ++ elif spec_info.accept_length is not None: ++ zeros = torch.zeros( ++ [self.verified_id.shape[0]], ++ dtype=self.accept_length.dtype, ++ device=self.accept_length.device, ++ ) ++ self.accept_length = torch.cat([zeros, spec_info.accept_length]) ++ self.accept_length_cpu = self.accept_length.tolist() + + + @dataclass diff --git a/docker/version.txt b/docker/version.txt index 40fee9d81..b480e0254 100644 --- a/docker/version.txt +++ b/docker/version.txt @@ -1 +1 @@ -nightly-dev-20251209d \ No newline at end of file +nightly-dev-20251212a \ No newline at end of file diff --git a/examples/geo3k_vlm/run_geo3k_vlm.py b/examples/geo3k_vlm/run_geo3k_vlm.py index 6f5a9c59e..f1d6adfa4 100644 --- a/examples/geo3k_vlm/run_geo3k_vlm.py +++ b/examples/geo3k_vlm/run_geo3k_vlm.py @@ -8,7 +8,6 @@ NUM_GPUS = int(os.environ.get("MILES_SCRIPT_NUM_GPUS", "1")) EXTERNAL_RAY = int(os.environ.get("MILES_SCRIPT_EXTERNAL_RAY", "0")) -MASTER_ADDR = os.environ.get("MASTER_ADDR", "127.0.0.1") def prepare(): @@ -40,7 +39,7 @@ def execute(): ) eval_args = ( - # "--eval-interval 20 " + "--eval-interval 20 " "--eval-prompt-data geo3k /root/datasets/geo3k_imgurl/test.parquet " "--n-samples-per-eval-prompt 1 " 
"--eval-max-response-len 4096 " @@ -119,27 +118,6 @@ def execute(): # f"{true_on_policy_args} " ) - # Kill existing processes - U.exec_command( - "pkill -9 sglang; " - "sleep 3; " - f"{'' if EXTERNAL_RAY else 'ray stop --force; '}" - f"{'' if EXTERNAL_RAY else 'pkill -9 ray; '}" - "pkill -9 miles; " - "sleep 3; " - f"{'' if EXTERNAL_RAY else 'pkill -9 ray; '}" - "pkill -9 miles; " - "pkill -9 redis; " - "true; " - ) - - if not EXTERNAL_RAY: - # Start Ray - U.exec_command( - f"export PYTHONBUFFERED=16 && " - f"ray start --head --node-ip-address {MASTER_ADDR} --num-gpus {NUM_GPUS} " - f"--disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265" - ) # Submit Ray job execute_train( train_args=train_args, diff --git a/examples/true_on_policy_vlm/README.md b/examples/true_on_policy_vlm/README.md index 246665311..786ec2c98 100644 --- a/examples/true_on_policy_vlm/README.md +++ b/examples/true_on_policy_vlm/README.md @@ -5,6 +5,7 @@ This example demonstrates true on-policy training with Qwen3-VL dense model on F

Training Inference Log Prob Diff

+ ## Usage ```bash diff --git a/examples/true_on_policy_vlm/run_simple.py b/examples/true_on_policy_vlm/run_simple.py index 89b8ec229..3fd6013bd 100644 --- a/examples/true_on_policy_vlm/run_simple.py +++ b/examples/true_on_policy_vlm/run_simple.py @@ -8,7 +8,6 @@ NUM_GPUS = int(os.environ.get("MILES_SCRIPT_NUM_GPUS", "1")) EXTERNAL_RAY = int(os.environ.get("MILES_SCRIPT_EXTERNAL_RAY", "0")) -MASTER_ADDR = os.environ.get("MASTER_ADDR", "127.0.0.1") def prepare(): @@ -39,7 +38,7 @@ def execute(): ) eval_args = ( - # "--eval-interval 20 " + "--eval-interval 20 " "--eval-prompt-data geo3k /root/datasets/geo3k_imgurl/test.parquet " "--n-samples-per-eval-prompt 1 " "--eval-max-response-len 4096 " @@ -127,28 +126,6 @@ def execute(): f"{true_on_policy_args} " ) - # Kill existing processes - U.exec_command( - "pkill -9 sglang; " - "sleep 3; " - f"{'' if EXTERNAL_RAY else 'ray stop --force; '}" - f"{'' if EXTERNAL_RAY else 'pkill -9 ray; '}" - "pkill -9 miles; " - "sleep 3; " - f"{'' if EXTERNAL_RAY else 'pkill -9 ray; '}" - "pkill -9 miles; " - "pkill -9 redis; " - "true; " - ) - - if not EXTERNAL_RAY: - # Start Ray - U.exec_command( - f"export PYTHONBUFFERED=16 && " - f"ray start --head --node-ip-address {MASTER_ADDR} --num-gpus {NUM_GPUS} " - f"--disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265" - ) - # Submit Ray job execute_train( train_args=train_args, diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py index 4ff6988f1..1e3e5b3ae 100644 --- a/miles/backends/fsdp_utils/actor.py +++ b/miles/backends/fsdp_utils/actor.py @@ -80,7 +80,8 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty if i == dist.get_rank(): self.hf_config = AutoConfig.from_pretrained(self.args.hf_checkpoint, trust_remote_code=True) self.tokenizer = load_tokenizer(self.args.hf_checkpoint, trust_remote_code=True) - if self.args.multimodal_keys: + # Vision models have `vision_config` in the config + if 
hasattr(self.hf_config, "vision_config"): self.processor = load_processor(self.args.hf_checkpoint, trust_remote_code=True) dist.barrier(group=get_gloo_group()) diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index 3b60157e1..c2eb949ef 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -71,6 +71,12 @@ def init( self.train_parallel_config = { "dp_size": mpu.get_data_parallel_world_size(with_context_parallel=False), } + dist.barrier(group=get_gloo_group()) + + if args.offload_train: + if (x := args.train_memory_margin_bytes) > 0: + logger.info(f"Set torch_memory_saver.memory_margin_bytes to {x}") + torch_memory_saver.memory_margin_bytes = x if self.args.debug_rollout_only: return 0 diff --git a/miles/backends/megatron_utils/initialize.py b/miles/backends/megatron_utils/initialize.py index 7c4357436..e9f062c11 100644 --- a/miles/backends/megatron_utils/initialize.py +++ b/miles/backends/megatron_utils/initialize.py @@ -7,7 +7,6 @@ from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator from megatron.training.global_vars import _build_tokenizer, set_args - logger = logging.getLogger(__name__) diff --git a/miles/ray/train_actor.py b/miles/ray/train_actor.py index 4244c022a..3d3923e9c 100644 --- a/miles/ray/train_actor.py +++ b/miles/ray/train_actor.py @@ -7,7 +7,6 @@ import ray import torch import torch.distributed as dist -from torch_memory_saver import torch_memory_saver import miles.utils.eval_config from miles.ray.ray_actor import RayActor @@ -53,11 +52,6 @@ def init(self, args, role, with_ref=False): self.role = role self.with_ref = with_ref - if (x := args.train_memory_margin_bytes) > 0: - logger.info(f"Set torch_memory_saver.memory_margin_bytes to {x}") - assert args.offload_train - torch_memory_saver.memory_margin_bytes = x - torch.serialization.add_safe_globals([miles.utils.eval_config.EvalDatasetConfig]) local_rank = 
int(os.environ.get("LOCAL_RANK", 0)) diff --git a/miles/utils/flops_utils.py b/miles/utils/flops_utils.py index 57cf9de90..71cdd4c65 100644 --- a/miles/utils/flops_utils.py +++ b/miles/utils/flops_utils.py @@ -16,8 +16,8 @@ def calculate_qkv_projection_flops(seqlen, hidden_size, num_attention_heads, num def calculate_attention_flops(seqlen, num_attention_heads, head_dim): - # QK^T - flops = 2 * num_attention_heads * seqlen * seqlen * head_dim + # QK^T with causal + flops = 2 * num_attention_heads * seqlen * seqlen * head_dim // 2 # A*V flops += 2 * num_attention_heads * seqlen * seqlen * head_dim return flops @@ -31,8 +31,9 @@ def calculate_mlp_flops(seqlen, hidden_size, ffn_hidden_size): return 2 * seqlen * hidden_size * ffn_hidden_size * 3 -def calculate_layer_flops(seqlen, hidden_size, num_attention_heads, num_query_groups, ffn_hidden_size): - head_dim = hidden_size // num_attention_heads +def calculate_layer_flops(seqlen, hidden_size, num_attention_heads, num_query_groups, ffn_hidden_size, head_dim): + if head_dim is None: + head_dim = hidden_size // num_attention_heads return ( calculate_qkv_projection_flops(seqlen, hidden_size, num_attention_heads, num_query_groups) + calculate_attention_flops(seqlen, num_attention_heads, head_dim) @@ -49,6 +50,7 @@ def calculate_fwd_flops( num_attention_heads = args.num_attention_heads num_query_groups = args.num_query_groups vocab_size = args.vocab_size + kv_channels = args.kv_channels total_flops = 0 @@ -82,6 +84,7 @@ def calculate_fwd_flops( num_attention_heads, num_query_groups, dense_ffn, + kv_channels, ) * num_dense_layers ) @@ -94,6 +97,7 @@ def calculate_fwd_flops( num_attention_heads, num_query_groups, moe_ffn, + kv_channels, ) * num_moe_layers ) diff --git a/setup.py b/setup.py index 22c18d712..8c0794320 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ def get_tag(self): setup( author="miles Team", name="miles", - version="0.2.0.post1", + version="0.2.1", packages=find_packages(include=["miles*", 
"miles_plugins*"]), include_package_data=True, install_requires=_fetch_requirements("requirements.txt"), From 5c365db19aedcc5447267ac1c551fc00bc1c70fd Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sun, 14 Dec 2025 13:56:32 +0800 Subject: [PATCH 04/21] Super tiny update link (#312) --- docs/en/get_started/quick_start.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/en/get_started/quick_start.md b/docs/en/get_started/quick_start.md index b562547fb..7c44bbd12 100644 --- a/docs/en/get_started/quick_start.md +++ b/docs/en/get_started/quick_start.md @@ -571,5 +571,5 @@ ray job submit --address="http://127.0.0.1:8265" \ miles has been deeply optimized for distributed training of large-scale Mixture of Experts (MoE) models. We provide some end-to-end training cases for reference: -- [Example: 64xH100 Training GLM-4.5](models/glm4.5-355B-A32B.md) -- [Example: 128xH100 Training DeepSeek-R1](models/deepseek-r1.md) +- [Example: 64xH100 Training GLM-4.5](../examples/glm4.5-355B-A32B.md) +- [Example: 128xH100 Training DeepSeek-R1](../examples/deepseek-r1.md) From 520216497f9f0ba8833cbb7eeed5308bcef36878 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sun, 14 Dec 2025 14:04:32 +0800 Subject: [PATCH 05/21] Tiny update doc about multi node training (#313) --- docs/en/get_started/quick_start.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/en/get_started/quick_start.md b/docs/en/get_started/quick_start.md index 7c44bbd12..7387e9ead 100644 --- a/docs/en/get_started/quick_start.md +++ b/docs/en/get_started/quick_start.md @@ -569,7 +569,17 @@ ray job submit --address="http://127.0.0.1:8265" \ --... # Other Megatron/SGLang/miles arguments ``` +Optionally, the following environment variables may be needed based on your environment. For example, when there are multiple IPs and the wrong one is chosen in a Docker or SLURM environment. 
We provide an example used in a SLURM + enroot multi-node system as follows: + +``` +export MILES_HOST_IP=$(hostname -I | awk '{print $1}') +export GLOO_SOCKET_IFNAME=$(ip -o -4 addr show | awk '$4 ~ /^10\./ {print $2}') +export NCCL_SOCKET_IFNAME=$(ip -o -4 addr show | awk '$4 ~ /^10\./ {print $2}') +export NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME=$(ip -o -4 addr show | awk '$4 ~ /^10\./ {print $2}') +``` + miles has been deeply optimized for distributed training of large-scale Mixture of Experts (MoE) models. We provide some end-to-end training cases for reference: - [Example: 64xH100 Training GLM-4.5](../examples/glm4.5-355B-A32B.md) - [Example: 128xH100 Training DeepSeek-R1](../examples/deepseek-r1.md) +- The scripts such as `scripts/run_qwen3_30b_a3b.py`, `scripts/run_glm45_355b_a32b.py` also support multi-node training, though there is little documentation about it currently. From bdc41f00d048c12768bfb0a431279eb5eb2fb81a Mon Sep 17 00:00:00 2001 From: Yueming Yuan Date: Fri, 19 Dec 2025 12:32:56 -0800 Subject: [PATCH 06/21] add explicit argument name for new megatron compatibility (#324) --- .../backends/megatron_utils/model_provider.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/miles/backends/megatron_utils/model_provider.py b/miles/backends/megatron_utils/model_provider.py index 04b3db1b6..5b7b3dd74 100644 --- a/miles/backends/megatron_utils/model_provider.py +++ b/miles/backends/megatron_utils/model_provider.py @@ -89,19 +89,19 @@ def model_provider(pre_process: bool = True, post_process: bool = True, vp_stage # Define the decoder layer spec if use_te: transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - args.num_experts, - args.moe_grouped_gemm, - args.qk_layernorm, - args.multi_latent_attention, - args.moe_use_legacy_grouped_gemm, + num_experts=args.num_experts, + moe_grouped_gemm=args.moe_grouped_gemm, + qk_layernorm=args.qk_layernorm, + multi_latent_attention=args.multi_latent_attention, + 
moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm, ) else: transformer_layer_spec = get_gpt_layer_local_spec( - args.num_experts, - args.moe_grouped_gemm, - args.qk_layernorm, - args.multi_latent_attention, - args.moe_use_legacy_grouped_gemm, + num_experts=args.num_experts, + moe_grouped_gemm=args.moe_grouped_gemm, + qk_layernorm=args.qk_layernorm, + multi_latent_attention=args.multi_latent_attention, + moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm, ) build_model_context = nullcontext From 996e31d69757e9a318a6cacd5f633ca8c5dd4926 Mon Sep 17 00:00:00 2001 From: zijiexia <37504505+zijiexia@users.noreply.github.com> Date: Mon, 22 Dec 2025 20:28:01 -0800 Subject: [PATCH 07/21] update outdated commands in docs (#339) --- docs/en/examples/deepseek-r1.md | 2 +- docs/en/examples/glm4-9B.md | 6 +++--- docs/en/examples/qwen3-30B-A3B.md | 2 +- docs/en/examples/qwen3-4B.md | 6 +++--- docs/en/get_started/quick_start.md | 9 +++------ 5 files changed, 11 insertions(+), 14 deletions(-) diff --git a/docs/en/examples/deepseek-r1.md b/docs/en/examples/deepseek-r1.md index 28f2f88f1..e1c24e3ad 100644 --- a/docs/en/examples/deepseek-r1.md +++ b/docs/en/examples/deepseek-r1.md @@ -16,7 +16,7 @@ For instructions on setting up the environment and downloading data, please refe To prepare the DeepSeek R1 checkpoint, first you will need to download DeepSeek-R1 to a directory accessible by all machines (hereinafter referred to as `$BASE_DIR`): ```bash -huggingface-cli download deepseek-ai/DeepSeek-R1 --local-dir $BASE_DIR/DeepSeek-R1 +hf download deepseek-ai/DeepSeek-R1 --local-dir $BASE_DIR/DeepSeek-R1 ``` The Hugging Face checkpoint for DeepSeek-R1 is in a block-quantized fp8 format. 
To convert it into a torch_dist format that Megatron can load, you first need to convert it to a bf16 Hugging Face checkpoint: diff --git a/docs/en/examples/glm4-9B.md b/docs/en/examples/glm4-9B.md index 5c4917e95..8024fb48a 100644 --- a/docs/en/examples/glm4-9B.md +++ b/docs/en/examples/glm4-9B.md @@ -15,14 +15,14 @@ Download the model and data: ```bash # hf checkpoint -huggingface-cli download zai-org/GLM-Z1-9B-0414 --local-dir /root/GLM-Z1-9B-0414 +hf download zai-org/GLM-Z1-9B-0414 --local-dir /root/GLM-Z1-9B-0414 # train data -huggingface-cli download --repo-type dataset zhuzilin/dapo-math-17k \ +hf download --repo-type dataset zhuzilin/dapo-math-17k \ --local-dir /root/dapo-math-17k # eval data -huggingface-cli download --repo-type dataset zhuzilin/aime-2024 \ +hf download --repo-type dataset zhuzilin/aime-2024 \ --local-dir /root/aime-2024 ``` diff --git a/docs/en/examples/qwen3-30B-A3B.md b/docs/en/examples/qwen3-30B-A3B.md index e51de0d6d..965ef7eb5 100644 --- a/docs/en/examples/qwen3-30B-A3B.md +++ b/docs/en/examples/qwen3-30B-A3B.md @@ -79,7 +79,7 @@ Here, we will briefly introduce the MoE-related parts in the [run-qwen3-30B-A3B. miles also supports BF16 training with FP8 inference. 
For the Qwen3-30B-A3B model, you just need to download the following model: ```bash -huggingface-cli download Qwen/Qwen3-30B-A3B-FP8 --local-dir /root/Qwen3-30B-A3B-FP8 +hf download Qwen/Qwen3-30B-A3B-FP8 --local-dir /root/Qwen3-30B-A3B-FP8 ``` And replace `--hf-checkpoint` with: diff --git a/docs/en/examples/qwen3-4B.md b/docs/en/examples/qwen3-4B.md index 27d2a4c1c..cb76a2f1f 100644 --- a/docs/en/examples/qwen3-4B.md +++ b/docs/en/examples/qwen3-4B.md @@ -15,14 +15,14 @@ Download the model and data: ```bash # hf checkpoint -huggingface-cli download Qwen/Qwen3-4B --local-dir /root/Qwen3-4B +hf download Qwen/Qwen3-4B --local-dir /root/Qwen3-4B # train data -huggingface-cli download --repo-type dataset zhuzilin/dapo-math-17k \ +hf download --repo-type dataset zhuzilin/dapo-math-17k \ --local-dir /root/dapo-math-17k # eval data -huggingface-cli download --repo-type dataset zhuzilin/aime-2024 \ +hf download --repo-type dataset zhuzilin/aime-2024 \ --local-dir /root/aime-2024 ``` diff --git a/docs/en/get_started/quick_start.md b/docs/en/get_started/quick_start.md index 7387e9ead..c065662ab 100644 --- a/docs/en/get_started/quick_start.md +++ b/docs/en/get_started/quick_start.md @@ -39,13 +39,12 @@ docker run --rm --gpus all --ipc=host --shm-size=16g \ ### Install miles -After entering the Docker container, please follow these steps to clone the miles repository and install it: +miles is already installed in the docker image. To update to the latest version, please execute the following command: ```bash # Path can be adjusted according to actual situation -cd /root/ -git clone https://github.com/radixark/miles.git -cd miles +cd /root/miles +git pull pip install -e . ``` @@ -54,8 +53,6 @@ pip install -e . You can download required models and datasets from platforms like Hugging Face, ModelScope, etc. 
Here are the commands to download example resources using `huggingface_hub`: ```bash -pip install -U huggingface_hub - # Download model weights (GLM-Z1-9B) hf download zai-org/GLM-Z1-9B-0414 --local-dir /root/GLM-Z1-9B-0414 From 7ca7026cb82b6daff1a8c53a26d60f8840c5f991 Mon Sep 17 00:00:00 2001 From: Zhuohao Li Date: Tue, 23 Dec 2025 15:30:40 -0800 Subject: [PATCH 08/21] tiny fix (#337) Co-authored-by: Li, Zhuohao --- scripts/run-glm4-9B-4xgpu-radixtree.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/scripts/run-glm4-9B-4xgpu-radixtree.sh b/scripts/run-glm4-9B-4xgpu-radixtree.sh index dbdc01c22..4a7d1d7b9 100755 --- a/scripts/run-glm4-9B-4xgpu-radixtree.sh +++ b/scripts/run-glm4-9B-4xgpu-radixtree.sh @@ -17,8 +17,6 @@ export PYTHONBUFFERED=16 export CUDA_VISIBLE_DEVICES=0,1,2,3 -WANDB_KEY=8920a59faeab83c97b55c3cbe78618f11d0a1821 - NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l) if [ "$NVLINK_COUNT" -gt 0 ]; then HAS_NVLINK=1 From 8626480be6f983e0c792714322bfa1b8d348b922 Mon Sep 17 00:00:00 2001 From: Yisheng Gong Date: Tue, 23 Dec 2025 16:30:15 -0800 Subject: [PATCH 09/21] Fix example rollout_temperature and top_k (#338) --- docs/en/examples/glm4-9B.md | 4 ++-- docs/en/examples/qwen3-4B.md | 4 ++-- docs/en/get_started/quick_start.md | 4 ++-- docs/en/platform_support/amd_tutorial.md | 4 ++-- examples/eval/scripts/run-qwen3-32B.sh | 2 +- examples/eval/scripts/run-qwen3-4B.sh | 2 +- examples/eval_multi_task/multi_task.sh | 2 +- examples/formal_math/single_round/run.py | 4 ++-- examples/formal_math/single_round/run_minimal.py | 4 ++-- examples/fully_async/run-qwen3-4b-fully_async.sh | 2 +- examples/geo3k_vlm/run_geo3k_vlm.py | 2 +- examples/low_precision/run-qwen3-30b-a3b-fp8-two-nodes.sh | 4 ++-- examples/low_precision/run-qwen3-4b-fp8.sh | 4 ++-- examples/multi_agent/README.md | 2 +- examples/multi_agent/run-qwen3-30B-A3B-multi-agent.sh | 4 ++-- examples/on_policy_distillation/run-qwen3-8B-opd.sh | 4 ++-- examples/reproducibility/run-qwen2.5-0.5B-gsm8k.sh 
| 2 +- examples/retool/retool_qwen3_4b_rl.sh | 4 ++-- examples/strands-agents/strands_qwen3_4b.sh | 4 ++-- examples/tau-bench/run_qwen3_4B.sh | 2 +- examples/train_infer_mismatch_helper/run-qwen3-4b-mis.sh | 4 ++-- examples/true_on_policy/run_simple.py | 2 +- examples/true_on_policy_vlm/run_simple.py | 2 +- miles/utils/debug_utils/send_to_sglang.py | 2 +- scripts/run-deepseek-r1.sh | 4 ++-- scripts/run-glm4-9B-4xgpu-radixtree.sh | 4 ++-- scripts/run-glm4-9B.sh | 4 ++-- scripts/run-glm4.5-355B-A32B.sh | 4 ++-- scripts/run-kimi-k2-Instruct.sh | 4 ++-- scripts/run-kimi-k2-Thinking.sh | 4 ++-- scripts/run-llama3.2-3B-Instruct-amd.sh | 4 ++-- scripts/run-mimo-7B-rl-eagle.sh | 4 ++-- scripts/run-moonlight-16B-A3B.sh | 4 ++-- scripts/run-qwen3-235B-A22B.sh | 4 ++-- scripts/run-qwen3-30B-A3B.sh | 4 ++-- scripts/run-qwen3-32B.sh | 4 ++-- scripts/run-qwen3-4B-amd.sh | 4 ++-- scripts/run-qwen3-4B-fsdp.sh | 2 +- scripts/run-qwen3-4B.sh | 4 ++-- scripts/run-qwen3-4B_4xgpu-radixtree.sh | 4 ++-- scripts/run-qwen3-4B_4xgpu.sh | 4 ++-- scripts/run-qwen3-8B-amd.sh | 4 ++-- scripts/run-qwen3-next-80B-A3B.sh | 4 ++-- scripts/run_deepseek.py | 4 ++-- scripts/run_glm45_355b_a32b.py | 4 ++-- scripts/run_mcore_fsdp.py | 4 ++-- scripts/run_qwen3_30b_a3b.py | 4 ++-- scripts/run_qwen3_4b.py | 4 ++-- tests/test_external_rollout.py | 2 +- tests/test_quick_start_glm4_9B.py | 2 +- tests/test_qwen2.5_0.5B_gsm8k.py | 2 +- tests/test_qwen2.5_0.5B_gsm8k_async.py | 2 +- tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py | 2 +- tests/test_qwen3_0.6B_fsdp_distributed.py | 2 +- tests/test_qwen3_30B_A3B.py | 2 +- 55 files changed, 91 insertions(+), 91 deletions(-) diff --git a/docs/en/examples/glm4-9B.md b/docs/en/examples/glm4-9B.md index 8024fb48a..f46e9f373 100644 --- a/docs/en/examples/glm4-9B.md +++ b/docs/en/examples/glm4-9B.md @@ -110,7 +110,7 @@ ROLLOUT_ARGS=( --n-samples-per-prompt 8 # Rollout sampling parameters --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 # 
Number of training steps corresponding to one rollout --num-steps-per-rollout 1 @@ -129,7 +129,7 @@ EVAL_ARGS=( --eval-prompt-data /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) ``` diff --git a/docs/en/examples/qwen3-4B.md b/docs/en/examples/qwen3-4B.md index cb76a2f1f..1966fd823 100644 --- a/docs/en/examples/qwen3-4B.md +++ b/docs/en/examples/qwen3-4B.md @@ -110,7 +110,7 @@ ROLLOUT_ARGS=( --n-samples-per-prompt 8 # Rollout sampling parameters --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 # Number of training steps corresponding to one rollout --num-steps-per-rollout 1 @@ -129,7 +129,7 @@ EVAL_ARGS=( --eval-prompt-data /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) ``` diff --git a/docs/en/get_started/quick_start.md b/docs/en/get_started/quick_start.md index c065662ab..db07ab705 100644 --- a/docs/en/get_started/quick_start.md +++ b/docs/en/get_started/quick_start.md @@ -200,7 +200,7 @@ ROLLOUT_ARGS=( # Rollout sampling parameters --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 # Load balancing for data collected in rollout phase. 
It ensures that the computational workload allocated to each training process (DP rank) is roughly equal, which may be beneficial for training speed --balance-data @@ -222,7 +222,7 @@ EVAL_ARGS=( # Maximum response length during evaluation --eval-max-response-len 16384 # Sampling parameters during evaluation - --eval-top-p 0.7 + --eval-top-p 1 ) ``` diff --git a/docs/en/platform_support/amd_tutorial.md b/docs/en/platform_support/amd_tutorial.md index 6c2def2d9..790aede5d 100644 --- a/docs/en/platform_support/amd_tutorial.md +++ b/docs/en/platform_support/amd_tutorial.md @@ -151,7 +151,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -162,7 +162,7 @@ EVAL_ARGS=( --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/examples/eval/scripts/run-qwen3-32B.sh b/examples/eval/scripts/run-qwen3-32B.sh index 0a235aeea..eb6702deb 100644 --- a/examples/eval/scripts/run-qwen3-32B.sh +++ b/examples/eval/scripts/run-qwen3-32B.sh @@ -54,7 +54,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data diff --git a/examples/eval/scripts/run-qwen3-4B.sh b/examples/eval/scripts/run-qwen3-4B.sh index 64b8c5cb1..4343377f1 100644 --- a/examples/eval/scripts/run-qwen3-4B.sh +++ b/examples/eval/scripts/run-qwen3-4B.sh @@ -54,7 +54,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data diff --git a/examples/eval_multi_task/multi_task.sh b/examples/eval_multi_task/multi_task.sh index eb028d91e..8236e6d2f 100644 --- 
a/examples/eval_multi_task/multi_task.sh +++ b/examples/eval_multi_task/multi_task.sh @@ -48,7 +48,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data diff --git a/examples/formal_math/single_round/run.py b/examples/formal_math/single_round/run.py index c688fe341..8cbb9d738 100644 --- a/examples/formal_math/single_round/run.py +++ b/examples/formal_math/single_round/run.py @@ -57,7 +57,7 @@ def execute(): "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 8192 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--global-batch-size 256 " "--balance-data " ) @@ -77,7 +77,7 @@ def execute(): "--eval-interval 20 " "--n-samples-per-eval-prompt 1 " f"--eval-max-response-len {eval_max_response_len or 16384} " - "--eval-top-p 0.7 " + "--eval-top-p 1 " ) if mode == "eval_flc": diff --git a/examples/formal_math/single_round/run_minimal.py b/examples/formal_math/single_round/run_minimal.py index dad439c2b..469d6b833 100644 --- a/examples/formal_math/single_round/run_minimal.py +++ b/examples/formal_math/single_round/run_minimal.py @@ -33,7 +33,7 @@ "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 8192 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--global-batch-size 256 " "--balance-data " "--num-rollout 3000 " @@ -43,7 +43,7 @@ "--eval-interval 20 " "--n-samples-per-eval-prompt 1 " "--eval-max-response-len 16384 " - "--eval-top-p 0.7 " + "--eval-top-p 1 " "--eval-prompt-data " "minif2f /root/datasets/formal_math_single_round/minimal_demo/minif2f_test.jsonl " ) diff --git a/examples/fully_async/run-qwen3-4b-fully_async.sh b/examples/fully_async/run-qwen3-4b-fully_async.sh index 63e7e7818..026e48608 100644 --- a/examples/fully_async/run-qwen3-4b-fully_async.sh +++ b/examples/fully_async/run-qwen3-4b-fully_async.sh @@ -52,7 +52,7 @@ 
ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data diff --git a/examples/geo3k_vlm/run_geo3k_vlm.py b/examples/geo3k_vlm/run_geo3k_vlm.py index f1d6adfa4..0106d2beb 100644 --- a/examples/geo3k_vlm/run_geo3k_vlm.py +++ b/examples/geo3k_vlm/run_geo3k_vlm.py @@ -34,7 +34,7 @@ def execute(): "--rollout-batch-size 64 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 4096 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--global-batch-size 512 " ) diff --git a/examples/low_precision/run-qwen3-30b-a3b-fp8-two-nodes.sh b/examples/low_precision/run-qwen3-30b-a3b-fp8-two-nodes.sh index 291fd8cb9..e761a3acd 100644 --- a/examples/low_precision/run-qwen3-30b-a3b-fp8-two-nodes.sh +++ b/examples/low_precision/run-qwen3-30b-a3b-fp8-two-nodes.sh @@ -49,7 +49,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 16 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 128 --balance-data @@ -60,7 +60,7 @@ EVAL_ARGS=( --eval-prompt-data aime "${BASE_DIR}/aime-2024.jsonl" --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/examples/low_precision/run-qwen3-4b-fp8.sh b/examples/low_precision/run-qwen3-4b-fp8.sh index 22e7d2e1f..b196ba606 100644 --- a/examples/low_precision/run-qwen3-4b-fp8.sh +++ b/examples/low_precision/run-qwen3-4b-fp8.sh @@ -46,7 +46,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -57,7 +57,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/data/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/examples/multi_agent/README.md 
b/examples/multi_agent/README.md index 0b6bf49fa..c974640bc 100644 --- a/examples/multi_agent/README.md +++ b/examples/multi_agent/README.md @@ -45,7 +45,7 @@ ROLLOUT_ARGS=( --n-samples-per-prompt 8 --rollout-max-context-len 16384 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data diff --git a/examples/multi_agent/run-qwen3-30B-A3B-multi-agent.sh b/examples/multi_agent/run-qwen3-30B-A3B-multi-agent.sh index f8e3952da..f3e5f1466 100644 --- a/examples/multi_agent/run-qwen3-30B-A3B-multi-agent.sh +++ b/examples/multi_agent/run-qwen3-30B-A3B-multi-agent.sh @@ -48,7 +48,7 @@ ROLLOUT_ARGS=( --n-samples-per-prompt 8 --rollout-max-context-len 16384 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -60,7 +60,7 @@ EVAL_ARGS=( # --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/examples/on_policy_distillation/run-qwen3-8B-opd.sh b/examples/on_policy_distillation/run-qwen3-8B-opd.sh index bb63a8513..c57b9eef4 100644 --- a/examples/on_policy_distillation/run-qwen3-8B-opd.sh +++ b/examples/on_policy_distillation/run-qwen3-8B-opd.sh @@ -63,7 +63,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 16 --n-samples-per-prompt 4 --rollout-max-response-len 16384 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 64 --balance-data @@ -80,7 +80,7 @@ EVAL_ARGS=( # --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl # --n-samples-per-eval-prompt 16 # --eval-max-response-len 16384 - # --eval-top-p 0.7 + # --eval-top-p 1 ) PERF_ARGS=( diff --git a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k.sh b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k.sh index 6cf2af46c..9fab2cd2f 100644 --- a/examples/reproducibility/run-qwen2.5-0.5B-gsm8k.sh +++ 
b/examples/reproducibility/run-qwen2.5-0.5B-gsm8k.sh @@ -34,7 +34,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 1024 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 ) diff --git a/examples/retool/retool_qwen3_4b_rl.sh b/examples/retool/retool_qwen3_4b_rl.sh index e3e8f0b2e..838ce0e2c 100644 --- a/examples/retool/retool_qwen3_4b_rl.sh +++ b/examples/retool/retool_qwen3_4b_rl.sh @@ -46,7 +46,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -57,7 +57,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/examples/strands-agents/strands_qwen3_4b.sh b/examples/strands-agents/strands_qwen3_4b.sh index 2df6ff0dd..647c8e2f5 100644 --- a/examples/strands-agents/strands_qwen3_4b.sh +++ b/examples/strands-agents/strands_qwen3_4b.sh @@ -47,7 +47,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -58,7 +58,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/data/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/examples/tau-bench/run_qwen3_4B.sh b/examples/tau-bench/run_qwen3_4B.sh index 25eb22545..a13067d0c 100644 --- a/examples/tau-bench/run_qwen3_4B.sh +++ b/examples/tau-bench/run_qwen3_4B.sh @@ -42,7 +42,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 1024 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --dynamic-sampling-filter-path 
miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std --balance-data diff --git a/examples/train_infer_mismatch_helper/run-qwen3-4b-mis.sh b/examples/train_infer_mismatch_helper/run-qwen3-4b-mis.sh index d8cac2e87..300e8ac75 100644 --- a/examples/train_infer_mismatch_helper/run-qwen3-4b-mis.sh +++ b/examples/train_infer_mismatch_helper/run-qwen3-4b-mis.sh @@ -46,7 +46,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -57,7 +57,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 1 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/examples/true_on_policy/run_simple.py b/examples/true_on_policy/run_simple.py index 6112877f9..1b472b806 100644 --- a/examples/true_on_policy/run_simple.py +++ b/examples/true_on_policy/run_simple.py @@ -32,7 +32,7 @@ def execute(): f"--rollout-batch-size {1 if MODE == 'debug_one_sample' else 32} " f"--n-samples-per-prompt {1 if MODE == 'debug_one_sample' else 8} " f"--rollout-max-response-len {2 if MODE == 'debug_one_sample' else 1024} " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " # temp remove this to make test easier # "--over-sampling-batch-size 64 " # "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " diff --git a/examples/true_on_policy_vlm/run_simple.py b/examples/true_on_policy_vlm/run_simple.py index 3fd6013bd..3f6e17541 100644 --- a/examples/true_on_policy_vlm/run_simple.py +++ b/examples/true_on_policy_vlm/run_simple.py @@ -33,7 +33,7 @@ def execute(): "--rollout-batch-size 64 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 4096 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--global-batch-size 512 " ) diff --git a/miles/utils/debug_utils/send_to_sglang.py 
b/miles/utils/debug_utils/send_to_sglang.py index c183edf79..0c2d4b57d 100644 --- a/miles/utils/debug_utils/send_to_sglang.py +++ b/miles/utils/debug_utils/send_to_sglang.py @@ -22,7 +22,7 @@ def main( Minimally send prompts to SGLang using OpenAI endpoints with arguments in the same format as main Miles. Example usage: - python -m miles.utils.debug_utils.send_to_sglang --prompt-data /root/datasets/aime-2024/aime-2024.jsonl --input-key prompt --n-samples-per-prompt 16 --rollout-max-response-len 32768 --rollout-temperature 0.8 --rollout-top-p 0.7 + python -m miles.utils.debug_utils.send_to_sglang --prompt-data /root/datasets/aime-2024/aime-2024.jsonl --input-key prompt --n-samples-per-prompt 16 --rollout-max-response-len 32768 --rollout-temperature 1 --rollout-top-p 1 """ async def _main_async(): diff --git a/scripts/run-deepseek-r1.sh b/scripts/run-deepseek-r1.sh index f8a3779b4..93e6c0f4b 100644 --- a/scripts/run-deepseek-r1.sh +++ b/scripts/run-deepseek-r1.sh @@ -46,7 +46,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 128 --n-samples-per-prompt 8 --rollout-max-response-len 32768 - --rollout-temperature 0.8 + --rollout-temperature 1 --over-sampling-batch-size 256 --dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std @@ -60,7 +60,7 @@ EVAL_ARGS=( --eval-prompt-data aime $BASE_DIR/rl_data/aime-2024.jsonl --n-samples-per-eval-prompt 8 --eval-max-response-len 32768 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-glm4-9B-4xgpu-radixtree.sh b/scripts/run-glm4-9B-4xgpu-radixtree.sh index 4a7d1d7b9..bd14b6d2c 100755 --- a/scripts/run-glm4-9B-4xgpu-radixtree.sh +++ b/scripts/run-glm4-9B-4xgpu-radixtree.sh @@ -49,7 +49,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -60,7 +60,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl 
--n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-glm4-9B.sh b/scripts/run-glm4-9B.sh index df613f187..b67523883 100644 --- a/scripts/run-glm4-9B.sh +++ b/scripts/run-glm4-9B.sh @@ -47,7 +47,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -58,7 +58,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-glm4.5-355B-A32B.sh b/scripts/run-glm4.5-355B-A32B.sh index 8771d1b77..0deaf0b88 100644 --- a/scripts/run-glm4.5-355B-A32B.sh +++ b/scripts/run-glm4.5-355B-A32B.sh @@ -42,7 +42,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 128 --n-samples-per-prompt 8 --rollout-max-response-len 32768 - --rollout-temperature 0.8 + --rollout-temperature 1 --over-sampling-batch-size 256 --dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std @@ -57,7 +57,7 @@ EVAL_ARGS=( --eval-prompt-data aime $BASE_DIR/rl_data/aime-2024.jsonl --n-samples-per-eval-prompt 8 --eval-max-response-len 32768 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-kimi-k2-Instruct.sh b/scripts/run-kimi-k2-Instruct.sh index 723432d9b..86919eff4 100644 --- a/scripts/run-kimi-k2-Instruct.sh +++ b/scripts/run-kimi-k2-Instruct.sh @@ -48,7 +48,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 128 --n-samples-per-prompt 8 --rollout-max-response-len 32768 - --rollout-temperature 0.8 + --rollout-temperature 1 # --global-batch-size 1024 @@ -64,7 +64,7 @@ EVAL_ARGS=( --eval-prompt-data aime $BASE_DIR/rl_data/aime-2024.jsonl --n-samples-per-eval-prompt 8 --eval-max-response-len 32768 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-kimi-k2-Thinking.sh 
b/scripts/run-kimi-k2-Thinking.sh index b559599ae..e5006b3b5 100644 --- a/scripts/run-kimi-k2-Thinking.sh +++ b/scripts/run-kimi-k2-Thinking.sh @@ -48,7 +48,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 128 --n-samples-per-prompt 8 --rollout-max-response-len 16384 - --rollout-temperature 0.8 + --rollout-temperature 1 # --global-batch-size 1024 @@ -64,7 +64,7 @@ EVAL_ARGS=( --eval-prompt-data aime $BASE_DIR/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-llama3.2-3B-Instruct-amd.sh b/scripts/run-llama3.2-3B-Instruct-amd.sh index 1e24a0660..a3036c304 100644 --- a/scripts/run-llama3.2-3B-Instruct-amd.sh +++ b/scripts/run-llama3.2-3B-Instruct-amd.sh @@ -62,7 +62,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 16384 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -73,7 +73,7 @@ EVAL_ARGS=( --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 8 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-mimo-7B-rl-eagle.sh b/scripts/run-mimo-7B-rl-eagle.sh index dec29d64d..092f25fef 100644 --- a/scripts/run-mimo-7B-rl-eagle.sh +++ b/scripts/run-mimo-7B-rl-eagle.sh @@ -47,7 +47,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -58,7 +58,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 1 --eval-max-response-len 8192 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-moonlight-16B-A3B.sh b/scripts/run-moonlight-16B-A3B.sh index 28234086a..ef695d398 100644 --- a/scripts/run-moonlight-16B-A3B.sh +++ b/scripts/run-moonlight-16B-A3B.sh @@ -46,7 +46,7 @@ ROLLOUT_ARGS=( 
--rollout-batch-size 128 --n-samples-per-prompt 8 --rollout-max-response-len 4096 - --rollout-temperature 0.8 + --rollout-temperature 1 --over-sampling-batch-size 256 --dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std @@ -61,7 +61,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 8 --eval-max-response-len 4096 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-235B-A22B.sh b/scripts/run-qwen3-235B-A22B.sh index 7f4892ac6..e42e17ab2 100644 --- a/scripts/run-qwen3-235B-A22B.sh +++ b/scripts/run-qwen3-235B-A22B.sh @@ -58,7 +58,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 8 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 64 --balance-data @@ -69,7 +69,7 @@ EVAL_ARGS=( --eval-prompt-data aime ${BASE_FOLDER}/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-30B-A3B.sh b/scripts/run-qwen3-30B-A3B.sh index 5ddfbce6e..19bc70927 100644 --- a/scripts/run-qwen3-30B-A3B.sh +++ b/scripts/run-qwen3-30B-A3B.sh @@ -47,7 +47,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -58,7 +58,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-32B.sh b/scripts/run-qwen3-32B.sh index 8bee49577..f6eb8240a 100644 --- a/scripts/run-qwen3-32B.sh +++ b/scripts/run-qwen3-32B.sh @@ -45,7 +45,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 
--balance-data @@ -56,7 +56,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-4B-amd.sh b/scripts/run-qwen3-4B-amd.sh index 6710096fe..321a9712d 100755 --- a/scripts/run-qwen3-4B-amd.sh +++ b/scripts/run-qwen3-4B-amd.sh @@ -55,7 +55,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -66,7 +66,7 @@ EVAL_ARGS=( --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-4B-fsdp.sh b/scripts/run-qwen3-4B-fsdp.sh index 42200d182..3c95442d5 100644 --- a/scripts/run-qwen3-4B-fsdp.sh +++ b/scripts/run-qwen3-4B-fsdp.sh @@ -51,7 +51,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 8 --n-samples-per-prompt 8 --rollout-max-response-len 4096 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 64 ) diff --git a/scripts/run-qwen3-4B.sh b/scripts/run-qwen3-4B.sh index f41a9f714..c7f01abd9 100644 --- a/scripts/run-qwen3-4B.sh +++ b/scripts/run-qwen3-4B.sh @@ -46,7 +46,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -57,7 +57,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-4B_4xgpu-radixtree.sh b/scripts/run-qwen3-4B_4xgpu-radixtree.sh index e16f93d7a..cee7adb6d 100644 --- a/scripts/run-qwen3-4B_4xgpu-radixtree.sh +++ b/scripts/run-qwen3-4B_4xgpu-radixtree.sh @@ -47,7 +47,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 
32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data ) @@ -57,7 +57,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-4B_4xgpu.sh b/scripts/run-qwen3-4B_4xgpu.sh index aa893f03b..931ef0408 100755 --- a/scripts/run-qwen3-4B_4xgpu.sh +++ b/scripts/run-qwen3-4B_4xgpu.sh @@ -50,7 +50,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -61,7 +61,7 @@ EVAL_ARGS=( --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-8B-amd.sh b/scripts/run-qwen3-8B-amd.sh index de814aab4..cdc18f5ca 100644 --- a/scripts/run-qwen3-8B-amd.sh +++ b/scripts/run-qwen3-8B-amd.sh @@ -63,7 +63,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -74,7 +74,7 @@ EVAL_ARGS=( --eval-prompt-data aime ${DATA_DIR}/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run-qwen3-next-80B-A3B.sh b/scripts/run-qwen3-next-80B-A3B.sh index f18df3690..d5d725124 100644 --- a/scripts/run-qwen3-next-80B-A3B.sh +++ b/scripts/run-qwen3-next-80B-A3B.sh @@ -56,7 +56,7 @@ ROLLOUT_ARGS=( --rollout-batch-size 32 --n-samples-per-prompt 8 --rollout-max-response-len 8192 - --rollout-temperature 0.8 + --rollout-temperature 1 --global-batch-size 256 --balance-data @@ -67,7 +67,7 @@ EVAL_ARGS=( --eval-prompt-data aime 
${BASE_FOLDER}/aime-2024/aime-2024.jsonl --n-samples-per-eval-prompt 16 --eval-max-response-len 16384 - --eval-top-p 0.7 + --eval-top-p 1 ) PERF_ARGS=( diff --git a/scripts/run_deepseek.py b/scripts/run_deepseek.py index 2d4a563b5..05fee4447 100644 --- a/scripts/run_deepseek.py +++ b/scripts/run_deepseek.py @@ -124,7 +124,7 @@ def train(args: ScriptArgs): "--num-rollout 3000 " "--rollout-batch-size 128 " "--n-samples-per-prompt 8 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " # ------------ "--num-steps-per-rollout 4 " "--balance-data " @@ -139,7 +139,7 @@ def train(args: ScriptArgs): # sometimes disable eval to speed up debugging eval_args = "" if (args.mode != "debug_minimal") and args.enable_eval: - eval_args += "--eval-interval 20 " "--eval-top-p 0.7 " + eval_args += "--eval-interval 20 " "--eval-top-p 1 " match args.task: case "dapo_aime": diff --git a/scripts/run_glm45_355b_a32b.py b/scripts/run_glm45_355b_a32b.py index ca91f9257..8015fdc4d 100644 --- a/scripts/run_glm45_355b_a32b.py +++ b/scripts/run_glm45_355b_a32b.py @@ -134,7 +134,7 @@ def train(args: ScriptArgs): # TODO enlarge "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " # ------------ # TODO enlarge "--num-steps-per-rollout 1 " @@ -151,7 +151,7 @@ def train(args: ScriptArgs): # sometimes disable eval to speed up debugging eval_args = "" if (args.mode != "debug_minimal") and args.enable_eval: - eval_args += "--eval-interval 20 " "--eval-top-p 0.7 " + eval_args += "--eval-interval 20 " "--eval-top-p 1 " match args.task: case "dapo_aime": diff --git a/scripts/run_mcore_fsdp.py b/scripts/run_mcore_fsdp.py index f0c76f91f..354624091 100644 --- a/scripts/run_mcore_fsdp.py +++ b/scripts/run_mcore_fsdp.py @@ -82,7 +82,7 @@ def execute(args: ScriptArgs): f"--rollout-batch-size {8 if args.mode == 'debug_minimal' else 64} " f"--n-samples-per-prompt {8 if args.mode == 'debug_minimal' else 16} " f"--rollout-max-response-len {100 if 
args.mode == 'debug_minimal' else 32768} " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " f"--global-batch-size {64 if args.mode == 'debug_minimal' else 1024} " ) @@ -122,7 +122,7 @@ def execute(args: ScriptArgs): "--eval-prompt-data aime /root/datasets/aime-2024/aime-2024.jsonl " "--n-samples-per-eval-prompt 16 " f"--eval-max-response-len {eval_max_response_len} " - "--eval-top-p 0.7 " + "--eval-top-p 1 " ) perf_args = ( diff --git a/scripts/run_qwen3_30b_a3b.py b/scripts/run_qwen3_30b_a3b.py index 6e226a613..26a374aa0 100644 --- a/scripts/run_qwen3_30b_a3b.py +++ b/scripts/run_qwen3_30b_a3b.py @@ -77,7 +77,7 @@ def execute(args: ScriptArgs): "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 8192} " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--global-batch-size 256 " "--balance-data " ) @@ -89,7 +89,7 @@ def execute(args: ScriptArgs): "--eval-prompt-data aime /root/datasets/aime-2024/aime-2024.jsonl " "--n-samples-per-eval-prompt 16 " "--eval-max-response-len 16384 " - "--eval-top-p 0.7 " + "--eval-top-p 1 " ) perf_args = ( diff --git a/scripts/run_qwen3_4b.py b/scripts/run_qwen3_4b.py index 8e074c2f9..d1aa63301 100644 --- a/scripts/run_qwen3_4b.py +++ b/scripts/run_qwen3_4b.py @@ -95,7 +95,7 @@ def execute(args: ScriptArgs): "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " f"--rollout-max-response-len {100 if args.mode == 'debug_minimal' else 8192} " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--global-batch-size 256 " "--balance-data " ) @@ -137,7 +137,7 @@ def execute(args: ScriptArgs): "--eval-prompt-data aime /root/datasets/aime-2024/aime-2024.jsonl " "--n-samples-per-eval-prompt 16 " f"--eval-max-response-len {eval_max_response_len} " - "--eval-top-p 0.7 " + "--eval-top-p 1 " ) grpo_args = ( diff --git a/tests/test_external_rollout.py b/tests/test_external_rollout.py index 117246bf1..c5c0838c5 100644 --- 
a/tests/test_external_rollout.py +++ b/tests/test_external_rollout.py @@ -33,7 +33,7 @@ def execute(): "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 1024 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--over-sampling-batch-size 64 " "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " "--global-batch-size 256 " diff --git a/tests/test_quick_start_glm4_9B.py b/tests/test_quick_start_glm4_9B.py index bf8453719..f18888c22 100644 --- a/tests/test_quick_start_glm4_9B.py +++ b/tests/test_quick_start_glm4_9B.py @@ -31,7 +31,7 @@ def execute(): "--rollout-batch-size 8 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 8192 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--global-batch-size 32 " "--balance-data " ) diff --git a/tests/test_qwen2.5_0.5B_gsm8k.py b/tests/test_qwen2.5_0.5B_gsm8k.py index d63cbe12c..6302aadb6 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k.py +++ b/tests/test_qwen2.5_0.5B_gsm8k.py @@ -29,7 +29,7 @@ def execute(): "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 1024 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--over-sampling-batch-size 64 " "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " "--global-batch-size 256 " diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async.py b/tests/test_qwen2.5_0.5B_gsm8k_async.py index 181d425a6..1c55ccb20 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async.py @@ -28,7 +28,7 @@ def execute(): "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 1024 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--over-sampling-batch-size 64 " "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " "--global-batch-size 256 " diff --git 
a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py index dd7bb21df..6967f9145 100644 --- a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py +++ b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py @@ -23,7 +23,7 @@ def execute(): "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 1024 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--over-sampling-batch-size 64 " "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " "--global-batch-size 256 " diff --git a/tests/test_qwen3_0.6B_fsdp_distributed.py b/tests/test_qwen3_0.6B_fsdp_distributed.py index 4d3d8b94a..b3eb416b3 100644 --- a/tests/test_qwen3_0.6B_fsdp_distributed.py +++ b/tests/test_qwen3_0.6B_fsdp_distributed.py @@ -27,7 +27,7 @@ def execute(): "--rollout-batch-size 32 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 1024 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--over-sampling-batch-size 64 " "--dynamic-sampling-filter-path miles.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std " "--global-batch-size 256 " diff --git a/tests/test_qwen3_30B_A3B.py b/tests/test_qwen3_30B_A3B.py index b1eaca583..6b5f6b889 100644 --- a/tests/test_qwen3_30B_A3B.py +++ b/tests/test_qwen3_30B_A3B.py @@ -35,7 +35,7 @@ def execute(): "--rollout-batch-size 8 " "--n-samples-per-prompt 8 " "--rollout-max-response-len 8192 " - "--rollout-temperature 0.8 " + "--rollout-temperature 1 " "--global-batch-size 32 " "--balance-data " ) From b3dc57d64eb27888197c61861a64c9e4c784cf19 Mon Sep 17 00:00:00 2001 From: Yueming Yuan Date: Tue, 23 Dec 2025 19:11:19 -0600 Subject: [PATCH 10/21] Supported `qkv_format=bshd` with CP (#341) --- miles/backends/megatron_utils/actor.py | 31 ++++++++-- miles/backends/megatron_utils/cp_utils.py | 62 +++++++++++++++---- miles/backends/megatron_utils/data.py | 75 ++++++++++++++--------- miles/backends/megatron_utils/loss.py | 52 
+++++++++++++--- miles/backends/megatron_utils/model.py | 8 ++- miles/utils/arguments.py | 17 +++++ 6 files changed, 188 insertions(+), 57 deletions(-) diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index c2eb949ef..bcbdd1a42 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -192,18 +192,37 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch: rollout_data["loss_masks"] = [ torch.tensor(t, dtype=torch.int, device=torch.cuda.current_device()) for t in rollout_data["loss_masks"] ] + + if self.args.qkv_format == "bshd": + # TODO: micro-batch wise dynamic, possibly move to @data.py:get_data_iterator + max_seq_len = max(rollout_data["total_lengths"]) + + # pad to reduce memory fragmentation and maybe make the computation faster + pad_size = mpu.get_tensor_model_parallel_world_size() * self.args.data_pad_size_multiplier + max_seq_len = (max_seq_len + pad_size - 1) // pad_size * pad_size + + rollout_data["max_seq_lens"] = [max_seq_len] * len(rollout_data["tokens"]) + if "rollout_log_probs" in rollout_data: rollout_data["rollout_log_probs"] = [ torch.tensor( - slice_log_prob_with_cp(log_prob, total_length, response_length), + slice_log_prob_with_cp( + log_prob, + total_length, + response_length, + self.args.qkv_format, + rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None, + ), device=torch.cuda.current_device(), dtype=torch.float32, ) - for log_prob, total_length, response_length in zip( - rollout_data["rollout_log_probs"], - rollout_data["total_lengths"], - rollout_data["response_lengths"], - strict=False, + for i, (log_prob, total_length, response_length) in enumerate( + zip( + rollout_data["rollout_log_probs"], + rollout_data["total_lengths"], + rollout_data["response_lengths"], + strict=False, + ) ) ] if "rollout_routed_experts" in rollout_data: diff --git a/miles/backends/megatron_utils/cp_utils.py 
b/miles/backends/megatron_utils/cp_utils.py index 92baa6954..2e795d3d3 100644 --- a/miles/backends/megatron_utils/cp_utils.py +++ b/miles/backends/megatron_utils/cp_utils.py @@ -9,6 +9,8 @@ def get_logits_and_tokens_offset_with_cp( total_length: int, response_length: int, + qkv_format: str = "thd", + max_seq_len: int | None = None, ): """ All offsets start from the begining of the prompt. @@ -18,7 +20,11 @@ def get_logits_and_tokens_offset_with_cp( assert cp_size > 1 prompt_length = total_length - response_length - chunk_size = (total_length + 2 * cp_size - 1) // (2 * cp_size) + if qkv_format == "thd": + chunk_size = (total_length + 2 * cp_size - 1) // (2 * cp_size) + else: + assert max_seq_len is not None, "max_seq_len must be provided for qkv_format=bshd" + chunk_size = (max_seq_len + 2 * cp_size - 1) // (2 * cp_size) # the offset of 2 chunks chunk_0 = (cp_rank * chunk_size, (cp_rank + 1) * chunk_size) @@ -49,6 +55,8 @@ def get_sum_of_sample_mean( response_lengths: list[int], loss_masks: list[torch.Tensor], calculate_per_token_loss: bool = False, + qkv_format: str = "thd", + max_seq_lens: list[int] | None = None, ) -> Callable[[torch.Tensor], torch.Tensor]: """ Calculate correct sample mean for CP @@ -78,8 +86,11 @@ def sum_of_token(x: torch.Tensor) -> torch.Tensor: for i, (total_length, response_length, loss_mask) in enumerate( zip(total_lengths, response_lengths, loss_masks, strict=False) ): + max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None prompt_length = total_length - response_length - _, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp(total_length, response_length) + _, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp( + total_length, response_length, qkv_format, max_seq_len + ) loss_mask_0 = loss_mask[tokens_offset[0][0] - prompt_length : tokens_offset[0][1] - prompt_length] loss_mask_1 = loss_mask[tokens_offset[1][0] - prompt_length : tokens_offset[1][1] - prompt_length] 
chunked_loss_masks.append(torch.cat([loss_mask_0, loss_mask_1], dim=0)) @@ -160,23 +171,44 @@ def zero(len: int) -> torch.Tensor: return full_tensor -def slice_with_cp(tokens: torch.Tensor, pad_value: tuple[int, float, Callable]) -> torch.Tensor: +def slice_with_cp( + tokens: torch.Tensor, + pad_value: tuple[int, float, Callable], + qkv_format: str = "thd", + max_seq_len: int | None = None, +) -> torch.Tensor: cp_rank = mpu.get_context_parallel_rank() cp_size = mpu.get_context_parallel_world_size() + if qkv_format == "bshd": + assert max_seq_len is not None + + def pad_tokens(tokens, pad): + if isinstance(pad_value, Callable): + pad_func = pad_value + tokens = pad_func(tokens, pad) + else: + # pad on the first dimension + pad_tuple = (0, 0) * (tokens.dim() - 1) + (0, pad) + tokens = F.pad(tokens, pad_tuple, value=pad_value) + return tokens + if cp_size == 1: + if qkv_format == "bshd": + pad = max_seq_len - tokens.size(0) + tokens = pad_tokens(tokens, pad) return tokens - # pad - chunk_size = (len(tokens) + 2 * cp_size - 1) // (2 * cp_size) - pad = 2 * cp_size * chunk_size - len(tokens) - if isinstance(pad_value, Callable): - pad_func = pad_value - tokens = pad_func(tokens, pad) + token_len = len(tokens) + if qkv_format == "thd": + chunk_size = (token_len + 2 * cp_size - 1) // (2 * cp_size) else: - # pad on the first dimension - pad_tuple = (0, 0) * (tokens.dim() - 1) + (0, pad) - tokens = F.pad(tokens, pad_tuple, value=pad_value) + chunk_size = (max_seq_len + 2 * cp_size - 1) // (2 * cp_size) + + # pad + pad = 2 * cp_size * chunk_size - token_len + tokens = pad_tokens(tokens, pad) + # get 2 chunk for thd cp start_1, end_1 = chunk_size * cp_rank, chunk_size * (cp_rank + 1) start_2, end_2 = chunk_size * (2 * cp_size - cp_rank - 1), chunk_size * (2 * cp_size - cp_rank) @@ -187,6 +219,8 @@ def slice_log_prob_with_cp( log_prob: list[float] | torch.Tensor, total_length: int, response_length: int, + qkv_format: str = "thd", + max_token_len: int | None = None, ) -> 
list[float] | torch.Tensor: assert len(log_prob) == response_length @@ -196,7 +230,9 @@ def slice_log_prob_with_cp( return log_prob prompt_length = total_length - response_length - _, _, logits_offset, _ = get_logits_and_tokens_offset_with_cp(total_length, response_length) + _, _, logits_offset, _ = get_logits_and_tokens_offset_with_cp( + total_length, response_length, qkv_format, max_token_len + ) chunk_1 = log_prob[logits_offset[0][0] - (prompt_length - 1) : logits_offset[0][1] - (prompt_length - 1)] chunk_2 = log_prob[logits_offset[1][0] - (prompt_length - 1) : logits_offset[1][1] - (prompt_length - 1)] diff --git a/miles/backends/megatron_utils/data.py b/miles/backends/megatron_utils/data.py index 73300bed6..f94d1b7e0 100644 --- a/miles/backends/megatron_utils/data.py +++ b/miles/backends/megatron_utils/data.py @@ -23,9 +23,7 @@ def get_batch( - data_iterator: "DataIterator", - keys: Sequence[str], - pad_multiplier: int = 128, + data_iterator: "DataIterator", keys: Sequence[str], pad_multiplier: int = 128, qkv_format: str = "thd" ) -> dict[str, torch.Tensor | PackedSeqParams | list[torch.Tensor] | None]: """ Generate a CP-ready micro-batch with packed sequence parameters. @@ -55,39 +53,50 @@ def get_batch( tokens = batch["tokens"] # use 0 as the pad token id should be fine? 
pad_token_id = 0 + pad_size = mpu.get_tensor_model_parallel_world_size() * pad_multiplier # for cp, we need all tokens to calculate logprob batch["unconcat_tokens"] = tokens cp_size = mpu.get_context_parallel_world_size() - tokens = [slice_with_cp(t, pad_token_id) for t in tokens] - - cu_seqlens = [0] - for t in tokens: - cu_seqlens.append(cu_seqlens[-1] + t.size(0)) - tokens = torch.cat(tokens) + if qkv_format == "bshd": + max_seqlen = batch["max_seq_lens"][0] + assert max([t.size(0) for t in tokens]) <= max_seqlen + tokens = [slice_with_cp(t, pad_token_id, qkv_format, max_seqlen) for t in tokens] + tokens = torch.stack(tokens) + + elif qkv_format == "thd": + tokens = [slice_with_cp(t, pad_token_id, qkv_format) for t in tokens] + + cu_seqlens = [0] + for t in tokens: + cu_seqlens.append(cu_seqlens[-1] + t.size(0)) + + tokens = torch.cat(tokens) + + # Always pad to reduce memory fragmentation and maybe make the computation faster + pad = (pad_size - tokens.size(0) % pad_size) % pad_size + if pad != 0: + tokens = F.pad(tokens, (0, pad), value=pad_token_id) + cu_seqlens.append(cu_seqlens[-1] + pad) + + # thd requires the cu_seqlens to be of the origin length + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int).cuda() * cp_size + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + + packed_seq_params = PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + qkv_format="thd", + ) - # Always pad to reduce memory fragmentation and maybe make the computation faster - pad_size = mpu.get_tensor_model_parallel_world_size() * pad_multiplier - pad = (pad_size - tokens.size(0) % pad_size) % pad_size - if pad != 0: - tokens = F.pad(tokens, (0, pad), value=pad_token_id) - cu_seqlens.append(cu_seqlens[-1] + pad) - - # thd requires the cu_seqlens to be of the origin length - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int).cuda() * cp_size - max_seqlen = (cu_seqlens[1:] - 
cu_seqlens[:-1]).max().item() - - packed_seq_params = PackedSeqParams( - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_kv=max_seqlen, - qkv_format="thd", - ) + tokens = tokens.unsqueeze(0) + else: + raise ValueError(f"Unsupported qkv_format: {qkv_format}") - tokens = tokens.unsqueeze(0) batch["tokens"] = tokens batch["packed_seq_params"] = packed_seq_params return batch @@ -311,6 +320,7 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc response_lengths = rollout_data["response_lengths"] loss_masks = rollout_data["loss_masks"] total_lengths = rollout_data["total_lengths"] + max_seq_lens = rollout_data.get("max_seq_lens", None) for key, val in rollout_data.items(): if key in [ @@ -318,6 +328,7 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc "loss_masks", "sample_indices", "rollout_routed_experts", + "max_seq_lens", ]: continue # Upload per sample mean for each rollout value @@ -329,7 +340,13 @@ def log_rollout_data(rollout_id: int, args: Namespace, rollout_data: RolloutBatc # modified in place and will cause problem for the next rollout. 
val = torch.cat(val).clone().detach() if key in ["log_probs", "ref_log_probs", "rollout_log_probs", "returns", "advantages", "values"]: - sum_of_sample_mean = get_sum_of_sample_mean(total_lengths, response_lengths, loss_masks) + sum_of_sample_mean = get_sum_of_sample_mean( + total_lengths, + response_lengths, + loss_masks, + qkv_format=args.qkv_format, + max_seq_lens=max_seq_lens, + ) val = cp_size * sum_of_sample_mean(val) / len(loss_masks) else: val = val.mean() * cp_size diff --git a/miles/backends/megatron_utils/loss.py b/miles/backends/megatron_utils/loss.py index 52ae3a9f4..d7b72a512 100644 --- a/miles/backends/megatron_utils/loss.py +++ b/miles/backends/megatron_utils/loss.py @@ -31,6 +31,7 @@ def get_responses( unconcat_tokens: list[torch.Tensor], total_lengths: list[int], response_lengths: list[int], + max_seq_lens: list[int] | None = None, ) -> Iterator[tuple[torch.Tensor, torch.Tensor]]: """Yield response-aligned `(logits_chunk, tokens_chunk)` pairs per sample. @@ -53,24 +54,40 @@ def get_responses( `[R, V]` (policy) or `[R, 1]` (value) and `tokens_chunk` is shape `[R]` (1D int64), both aligned to response tokens for one sample. 
""" - assert logits.size(0) == 1, f"{logits.shape}" + qkv_format = args.qkv_format + assert logits.dtype == torch.float32, f"{logits.dtype}" + assert len(logits.shape) == 3, f"{logits.shape}" + + if qkv_format == "thd": + assert logits.size(0) == 1, f"{logits.shape}" + logits = logits.squeeze(0) + else: + assert max_seq_lens is not None + logits = logits.view(-1, logits.size(-1)) - logits = logits.squeeze(0) logits = logits.div(args.rollout_temperature) cp_size = mpu.get_context_parallel_world_size() end = 0 - for tokens, total_length, response_length in zip(unconcat_tokens, total_lengths, response_lengths, strict=False): + for i, (tokens, total_length, response_length) in enumerate( + zip(unconcat_tokens, total_lengths, response_lengths, strict=False) + ): + max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None + if cp_size == 1: - end += total_length - start = end - response_length + if qkv_format == "bshd": + end = max_seq_len * i + total_length + start = end - response_length + else: + end += total_length + start = end - response_length logits_chunk = logits[start - 1 : end - 1] tokens_chunk = tokens[-response_length:] else: # TODO: this is super ugly... do better abstraction. chunk_size, chunks_offset, logits_offset, tokens_offset = get_logits_and_tokens_offset_with_cp( - total_length, response_length + total_length, response_length, qkv_format, max_seq_len ) logits_0, logits_1 = logits[end : end + chunk_size], logits[end + chunk_size : end + 2 * chunk_size] @@ -100,6 +117,7 @@ def get_log_probs_and_entropy( response_lengths: list[int], with_entropy: bool = False, non_loss_data: bool = True, + max_seq_lens: list[int] | None = None, ) -> dict[str, list[torch.Tensor]]: """Compute per-token log-probabilities (and optionally entropy) on responses. 
@@ -132,6 +150,7 @@ def get_log_probs_and_entropy( unconcat_tokens=unconcat_tokens, total_lengths=total_lengths, response_lengths=response_lengths, + max_seq_lens=max_seq_lens, ): log_prob, entropy = calculate_log_probs_and_entropy( logits_chunk, tokens_chunk, mpu.get_tensor_model_parallel_group(), with_entropy=with_entropy @@ -157,6 +176,7 @@ def get_values( response_lengths: list[int], with_entropy: bool = False, non_loss_data: bool = True, + max_seq_lens: list[int] | None = None, ) -> dict[str, list[torch.Tensor]]: """Extract per-token value predictions over response tokens. @@ -184,6 +204,7 @@ def get_values( unconcat_tokens=unconcat_tokens, total_lengths=total_lengths, response_lengths=response_lengths, + max_seq_lens=max_seq_lens, ): assert logits_chunk.size(-1) == 1, f"{logits_chunk.shape}" value_list.append(logits_chunk.squeeze(-1)) @@ -221,6 +242,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) response_lengths: list[int] = rollout_data.get("response_lengths") loss_masks: list[torch.Tensor] = rollout_data.get("loss_masks") total_lengths: list[int] = rollout_data.get("total_lengths") + max_seq_lens: list[int] | None = rollout_data.get("max_seq_lens", None) # return when not the last pp stage. 
if log_probs is None and values is None: @@ -314,8 +336,11 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) total_len = total_lengths[i] response_len = response_lengths[i] prompt_len = total_len - response_len + max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None - _, _, _, token_offsets = get_logits_and_tokens_offset_with_cp(total_len, response_len) + _, _, _, token_offsets = get_logits_and_tokens_offset_with_cp( + total_len, response_len, args.qkv_format, max_seq_len + ) # Convert global offsets to response-space offsets s0, e0 = token_offsets[0] @@ -444,6 +469,7 @@ def policy_loss_function( response_lengths = batch["response_lengths"] total_lengths = batch["total_lengths"] + max_seq_lens = batch.get("max_seq_lens", None) log_probs_and_entropy = get_log_probs_and_entropy( logits, @@ -452,6 +478,7 @@ def policy_loss_function( total_lengths=total_lengths, response_lengths=response_lengths, with_entropy=True, + max_seq_lens=max_seq_lens, ) log_probs = log_probs_and_entropy["log_probs"] @@ -529,7 +556,12 @@ def policy_loss_function( # [decouple IS and rejection] Rebuild sum_of_sample_mean with modified_response_masks for denominator correction # modified_response_masks will be sliced with cp in get_sum_of_sample_mean sum_of_sample_mean = get_sum_of_sample_mean( - total_lengths, response_lengths, modified_response_masks, args.calculate_per_token_loss + total_lengths, + response_lengths, + modified_response_masks, + args.calculate_per_token_loss, + args.qkv_format, + batch.get("max_seq_lens", None), ) pg_loss = sum_of_sample_mean(pg_loss) @@ -626,6 +658,7 @@ def value_loss_function( unconcat_tokens=batch["unconcat_tokens"], total_lengths=batch["total_lengths"], response_lengths=batch["response_lengths"], + max_seq_lens=batch.get("max_seq_lens", None), ) values = torch.cat([value.flatten() for value in values["values"]], dim=0) @@ -684,6 +717,7 @@ def sft_loss_function( total_lengths=total_lengths, 
response_lengths=response_lengths, with_entropy=False, + max_seq_lens=batch.get("max_seq_lens", None), ) log_probs = log_probs_and_entropy["log_probs"] @@ -739,6 +773,8 @@ def loss_function( batch["response_lengths"], batch["loss_masks"], args.calculate_per_token_loss, + args.qkv_format, + batch.get("max_seq_lens", None), ) match args.loss_type: diff --git a/miles/backends/megatron_utils/model.py b/miles/backends/megatron_utils/model.py index 7e8c8770d..c4e182797 100644 --- a/miles/backends/megatron_utils/model.py +++ b/miles/backends/megatron_utils/model.py @@ -210,7 +210,10 @@ def forward_step( # Get the batch. batch = get_batch( - data_iterator, ["tokens", "total_lengths", "response_lengths"], args.data_pad_size_multiplier + data_iterator, + ["tokens", "total_lengths", "response_lengths", "max_seq_lens"], + args.data_pad_size_multiplier, + args.qkv_format, ) unconcat_tokens = batch["unconcat_tokens"] tokens = batch["tokens"] @@ -232,6 +235,7 @@ def forward_step( total_lengths=total_lengths, response_lengths=response_lengths, with_entropy=args.use_rollout_entropy, + max_seq_lens=batch.get("max_seq_lens", None), ) # Turn on evaluation mode which disables dropout. 
@@ -361,8 +365,10 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p "advantages", "returns", "rollout_log_probs", + "max_seq_lens", ], args.data_pad_size_multiplier, + args.qkv_format, ) if os.environ.get("ENABLE_ROUTING_REPLAY", "0") == "1": diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index b2d5cc13f..ce6e47161 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -118,6 +118,13 @@ def add_train_arguments(parser): default="megatron", help="The backend for training.", ) + parser.add_argument( + "--qkv-format", + type=str, + choices=["thd", "bshd"], + default="thd", + help="The qkv layout for Megatron backend.", + ) parser.add_argument( "--true-on-policy-mode", action="store_true", @@ -1585,6 +1592,16 @@ def miles_validate_args(args): if args.prefill_num_servers is not None: assert not args.use_fault_tolerance, "fault tolerance is not supported when prefill_num_servers is set." + assert args.qkv_format in [ + "thd", + "bshd", + ], f"qkv_format {args.qkv_format} is not supported. (only 'thd' and 'bshd' are supported)" + if args.qkv_format == "bshd": + assert args.train_backend == "megatron", "bshd format is only supported for megatron backend." + assert ( + args.use_dynamic_batch_size is False + ), "Dynamic batch size is not supported for bshd format. Please specify --micro-batch-size instead." + def hf_validate_args(args, hf_config): def equal(x, y): From b777306edbc8e2006f08d2533ee120cc59fdb9be Mon Sep 17 00:00:00 2001 From: "lg(x)" <70553669+GuanxingLu@users.noreply.github.com> Date: Thu, 25 Dec 2025 01:55:02 -0500 Subject: [PATCH 11/21] Add LoRA for FSDP backend. 
(#307) (#326) Co-authored-by: PopSoda2002 --- miles/backends/fsdp_utils/actor.py | 11 +- miles/backends/fsdp_utils/arguments.py | 7 ++ miles/backends/fsdp_utils/checkpoint.py | 47 ++++++-- miles/backends/fsdp_utils/lora_utils.py | 77 ++++++++++++ .../fsdp_utils/update_weight_utils.py | 112 +++++++++++++----- miles/backends/sglang_utils/sglang_engine.py | 26 +++- miles/ray/placement_group.py | 1 + miles/ray/rollout.py | 4 +- miles/rollout/sglang_rollout.py | 5 + miles/utils/arguments.py | 30 +++++ requirements.txt | 1 + train.py | 13 +- 12 files changed, 284 insertions(+), 50 deletions(-) create mode 100644 miles/backends/fsdp_utils/lora_utils.py diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py index 1e3e5b3ae..2cfbc4339 100644 --- a/miles/backends/fsdp_utils/actor.py +++ b/miles/backends/fsdp_utils/actor.py @@ -28,6 +28,7 @@ from ...utils.profile_utils import TrainProfiler from . import checkpoint from .data_packing import pack_sequences, pad_packed_sequence_with_cp, unpack_sequences +from .lora_utils import apply_lora_to_model, is_lora_model from .lr_scheduler import get_lr_scheduler from .update_weight_utils import UpdateWeightFromDistributed, UpdateWeightFromTensor @@ -94,6 +95,9 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty attn_implementation=self.args.attn_implementation, ) + if self.args.lora_rank > 0 or self.args.lora_adapter_path: + model = apply_lora_to_model(model, self.args) + model.train() full_state = model.state_dict() @@ -107,11 +111,14 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty self.model = model if args.gradient_checkpointing: - self.model.gradient_checkpointing_enable() + # Avoid "does not require grad" error + gc_kwargs = {"use_reentrant": False} if is_lora_model(self.model) else {} + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs) if args.optimizer == "adam": + trainable_params = [p for p in 
self.model.parameters() if p.requires_grad] self.optimizer = torch.optim.AdamW( - self.model.parameters(), + trainable_params, lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps, diff --git a/miles/backends/fsdp_utils/arguments.py b/miles/backends/fsdp_utils/arguments.py index a319fe6e5..7cd10867b 100644 --- a/miles/backends/fsdp_utils/arguments.py +++ b/miles/backends/fsdp_utils/arguments.py @@ -60,6 +60,13 @@ class FSDPArgs: # YAML bookkeeping config: str | None = None + # LoRA configuration + lora_rank: int = 0 + lora_alpha: int = 16 + target_modules: str = "all-linear" + exclude_modules: str | None = None + lora_adapter_path: str | None = None + def parse_fsdp_cli(extra_args_provider=None): parser = argparse.ArgumentParser("FSDP SFT Training (miles)") diff --git a/miles/backends/fsdp_utils/checkpoint.py b/miles/backends/fsdp_utils/checkpoint.py index 3c49a10f8..8508fba2b 100644 --- a/miles/backends/fsdp_utils/checkpoint.py +++ b/miles/backends/fsdp_utils/checkpoint.py @@ -12,21 +12,34 @@ from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict from torch.distributed.checkpoint.stateful import Stateful +from miles.backends.fsdp_utils.lora_utils import is_lora_model + logger = logging.getLogger(__name__) class ModelState(Stateful): """Wrapper for model state only.""" - def __init__(self, model): + def __init__(self, model, lora_only: bool = False): self.model = model + self.lora_only = lora_only + self._key = "adapter" if lora_only else "model" def state_dict(self): model_state_dict, _ = get_state_dict(self.model, optimizers=[]) - return {"model": model_state_dict} + if self.lora_only: + model_state_dict = {k: v for k, v in model_state_dict.items() if "lora_" in k} + return {self._key: model_state_dict} def load_state_dict(self, state_dict): - set_state_dict(self.model, optimizers=[], model_state_dict=state_dict["model"], optim_state_dict=None) + data = state_dict[self._key] + + if self.lora_only: + full_state_dict, 
_ = get_state_dict(self.model, optimizers=[]) + full_state_dict.update(data) + set_state_dict(self.model, optimizers=[], model_state_dict=full_state_dict, optim_state_dict=None) + else: + set_state_dict(self.model, optimizers=[], model_state_dict=data, optim_state_dict=None) class OptimizerState(Stateful): @@ -103,20 +116,22 @@ def load(actor: Any) -> dict[str, Any] | None: model_dir = checkpoint_dir / "model" optimizer_dir = checkpoint_dir / "optimizer" lr_scheduler_dir = checkpoint_dir / "lr_scheduler" + lora_dir = checkpoint_dir / "adapter" + + lora_only = lora_dir.exists() and is_lora_model(actor.model) + model_dir = lora_dir if lora_only else model_dir if not model_dir.exists(): - logger.info(f"[FSDP] Model checkpoint {model_dir} not found; skipping load.") + logger.info(f"[FSDP] No model checkpoint found at {model_dir} or {lora_dir}; skipping load.") return None - # Load model weights (always) - model_state = ModelState(actor.model) + model_state = ModelState(actor.model, lora_only=lora_only) state_dict = {"model_state": model_state} - try: dcp.load(state_dict=state_dict, checkpoint_id=str(model_dir)) - logger.info(f"[FSDP] Loaded model from {model_dir}") + logger.info(f"[FSDP] Loaded {'LoRA adapter' if lora_only else 'model'} from {model_dir}") except Exception as e: - logger.error(f"[FSDP] Failed to load model from {model_dir}: {e}") + logger.error(f"[FSDP] Failed to load {'LoRA adapter' if lora_only else 'model'} from {model_dir}: {e}") return None # Load optimizer state (optional) @@ -210,9 +225,19 @@ def save(actor: Any, iteration: int) -> None: dist.barrier() # Save model weights - model_state = ModelState(actor.model) + lora_only = is_lora_model(actor.model) + if lora_only: + save_dir = checkpoint_dir / "adapter" + if dist.get_rank() == 0: + save_dir.mkdir(parents=True, exist_ok=True) + dist.barrier() + else: + save_dir = model_dir + + model_state = ModelState(actor.model, lora_only=lora_only) state_dict = {"model_state": model_state} - 
dcp.save(state_dict, checkpoint_id=str(model_dir)) + dcp.save(state_dict, checkpoint_id=str(save_dir)) + logger.info(f"[FSDP] Saved {'LoRA adapter' if lora_only else 'model'} to {save_dir}") # Save optimizer state if hasattr(actor, "optimizer") and actor.optimizer is not None: diff --git a/miles/backends/fsdp_utils/lora_utils.py b/miles/backends/fsdp_utils/lora_utils.py new file mode 100644 index 000000000..d6483b372 --- /dev/null +++ b/miles/backends/fsdp_utils/lora_utils.py @@ -0,0 +1,77 @@ +import logging +import os +import shutil +from pathlib import Path + +import torch.distributed as dist +import torch.nn as nn +from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict + +try: + from peft import LoraConfig, PeftModel, TaskType, get_peft_model +except ImportError as err: + raise ImportError("peft library required for LoRA. Install with: pip install peft") from err + +logger = logging.getLogger(__name__) + +LORA_READY_MARKER = ".lora_ready" +LORA_ADAPTER_NAME = "miles_lora" +LORA_SUBDIR = "tmp_lora" + + +def apply_lora_to_model(model: nn.Module, args) -> nn.Module: + if args.lora_adapter_path: + logger.info(f"Loading LoRA adapter from {args.lora_adapter_path}") + model = PeftModel.from_pretrained(model, args.lora_adapter_path, is_trainable=True) + peft_config = model.peft_config["default"] + if isinstance(peft_config.task_type, str): + peft_config.task_type = TaskType.CAUSAL_LM + model.print_trainable_parameters() + return model + + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + bias="none", + ) + + model = get_peft_model(model, lora_config) # autocast_adapter_dtype=False) + model.print_trainable_parameters() + logger.info(f"Applied LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}") + return model + + +def is_lora_model(module: nn.Module) -> bool: + unwrapped = getattr(module, "_fsdp_wrapped_module", module) + return 
hasattr(unwrapped, "peft_config") + + +def save_lora_to_disk(module: nn.Module, save_dir: str) -> str: + """Save LoRA adapter to disk with file lock mechanism.""" + # TODO: All gather lora layers not full layers + options = StateDictOptions(full_state_dict=True, cpu_offload=True) + full_state_dict = get_model_state_dict(module, options=options) + + lora_state_dict = {name: param for name, param in full_state_dict.items() if "lora_" in name} + + if dist.get_rank() == 0: + save_path = Path(save_dir) + save_path.mkdir(parents=True, exist_ok=True) + + module.save_pretrained(str(save_path), state_dict=lora_state_dict) + + # TODO: check if file lock is needed or better way to do it + os.sync() + + logger.info(f"Saved LoRA adapter to {save_path}") + return save_dir + + +def delete_lora_from_disk(save_dir: str) -> None: + """Delete LoRA adapter files from disk.""" + save_path = Path(save_dir) + if save_path.exists(): + shutil.rmtree(save_path) + logger.info(f"Deleted LoRA adapter from {save_path}") diff --git a/miles/backends/fsdp_utils/update_weight_utils.py b/miles/backends/fsdp_utils/update_weight_utils.py index c8dcbd810..6e4ee73a5 100644 --- a/miles/backends/fsdp_utils/update_weight_utils.py +++ b/miles/backends/fsdp_utils/update_weight_utils.py @@ -1,5 +1,6 @@ import abc import logging +import os import socket from argparse import Namespace from collections.abc import Sequence @@ -25,6 +26,7 @@ except ImportError: from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] +from .lora_utils import LORA_ADAPTER_NAME, LORA_SUBDIR, delete_lora_from_disk, is_lora_model, save_lora_to_disk logger = logging.getLogger(__name__) @@ -33,6 +35,9 @@ class UpdateWeight(abc.ABC): def __init__(self, args: Namespace, model: torch.nn.Module) -> None: self.args = args self.model = model + self.weight_version = 0 + self._lora_loaded = False + self._base_synced = False @abc.abstractmethod def connect_rollout_engines( @@ -43,38 +48,85 @@ def 
connect_rollout_engines( pass def update_weights(self) -> None: - bucket = [] - bucket_size = 0 - for name, param in self.model.state_dict().items(): - param_size = param.numel() * param.element_size() - if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size: - self.wait_and_update_bucket_weights(bucket) - del bucket - bucket = [] - bucket_size = 0 - - param = param.cuda() - if isinstance(param, DTensor): - # async version of param.full_tensor - param = param.redistribute( - placements=[Replicate()] * param.device_mesh.ndim, - async_op=True, - ).to_local() - bucket.append((name, param)) - bucket_size += param_size - - if bucket: - self.wait_and_update_bucket_weights(bucket) - del bucket + self.weight_version += 1 + + # Update base model if needed + # Level 1: only sync base once for LoRA models, then just LoRA + # Level 2: always sync base + LoRA + if not (is_lora_model(self.model) and self._base_synced and self.args.offload_rollout_level == 1): bucket = [] bucket_size = 0 + for name, param in self.model.state_dict().items(): + if any(x in name for x in ["_flat_param", "lora_"]): + continue + name = name.replace("base_model.model.", "").replace(".base_layer", "") + param_size = param.numel() * param.element_size() + if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size: + self.wait_and_update_bucket_weights(bucket) + del bucket + bucket = [] + bucket_size = 0 + + param = param.cuda() + if isinstance(param, DTensor): + # async version of param.full_tensor + param = param.redistribute( + placements=[Replicate()] * param.device_mesh.ndim, + async_op=True, + ).to_local() + bucket.append((name, param)) + bucket_size += param_size + + if bucket: + self.wait_and_update_bucket_weights(bucket) + del bucket + + self._base_synced = True + + # Update lora weights if needed + if is_lora_model(self.model): + self._update_lora_via_file() + + def _update_lora_via_file(self) -> None: + """Push LoRA weights to rollout engines using disk 
files.""" + self._lora_save_dir = os.path.join(self.args.save, LORA_SUBDIR) + if dist.get_rank() == 0: + if os.path.exists(self._lora_save_dir): + delete_lora_from_disk(self._lora_save_dir) + + dist.barrier() + + save_lora_to_disk(self.model, self._lora_save_dir) + + dist.barrier() + + if dist.get_rank() == 0: + if self._lora_loaded: + refs = [engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for engine in self.rollout_engines] + ray.get(refs) + + refs = [engine.flush_cache.remote() for engine in self.rollout_engines] + ray.get(refs) + + refs = [ + engine.load_lora_adapter.remote(LORA_ADAPTER_NAME, self._lora_save_dir) + for engine in self.rollout_engines + ] + ray.get(refs) + + refs = [engine.flush_cache.remote() for engine in self.rollout_engines] + ray.get(refs) + + self._lora_loaded = True + + dist.barrier() def wait_and_update_bucket_weights(self, bucket): bucket = [(name, param.wait()) if hasattr(param, "wait") else (name, param) for name, param in bucket] - self.update_bucket_weights(bucket) + self.update_bucket_weights(bucket, weight_version=self.weight_version) @abc.abstractmethod - def update_bucket_weights(self, named_tensors) -> None: + def update_bucket_weights(self, named_tensors, weight_version=None) -> None: pass @@ -114,7 +166,7 @@ def connect_rollout_engines( # Calculate TP rank within this SGLang engine group self.tp_rank = dist.get_rank() - start_rank - def update_bucket_weights(self, named_tensors) -> None: + def update_bucket_weights(self, named_tensors, weight_version=None) -> None: monkey_patch_torch_reductions() # Use flattened bucket approach similar to Megatron logger.info("Using flattened tensor bucket") @@ -162,6 +214,7 @@ def update_bucket_weights(self, named_tensors) -> None: "serialized_named_tensors": [tensors[i] for tensors in gathered_serialized_batches], "load_format": "flattened_bucket", "flush_cache": False, + "weight_version": str(weight_version), } ref = self._ipc_engine.update_weights_from_tensor.remote(**kwargs) 
ray.get(ref) @@ -174,10 +227,6 @@ def update_bucket_weights(self, named_tensors) -> None: class UpdateWeightFromDistributed(UpdateWeight): """Broadcast weights via a temporary NCCL group to rollout engines.""" - def __init__(self, args: Namespace, model: torch.nn.Module) -> None: - self.args = args - self.model = model - def connect_rollout_engines( self, rollout_engines: Sequence[ActorHandle], @@ -220,7 +269,7 @@ def connect_rollout_engines( ) ray.get(refs) - def update_bucket_weights(self, named_tensors) -> None: + def update_bucket_weights(self, named_tensors, weight_version=None) -> None: """Send names/dtypes/shapes metadata to engines, then broadcast tensors. Ensures tensors are contiguous; when `world_size == 1`, converts DTensors @@ -235,6 +284,7 @@ def update_bucket_weights(self, named_tensors) -> None: dtypes=[param.dtype for _, param in named_tensors], shapes=[param.shape for _, param in named_tensors], group_name=self._group_name, + weight_version=str(weight_version), ) for engine in self.rollout_engines ] diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index 2e1afe625..c9a774b32 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -278,9 +278,15 @@ def get_weight_version(self): response.raise_for_status() return response.json()["weight_version"] - def release_memory_occupation(self): + def release_memory_occupation(self, tags: list[str] = None): + """ + Available tags for multi-stage resume: weights, kv_cache + """ self.flush_cache() - return self._make_request("release_memory_occupation") + return self._make_request( + "release_memory_occupation", + {"tags": tags}, + ) def resume_memory_occupation(self, tags: list[str] = None): """ @@ -336,6 +342,18 @@ def update_weights_from_distributed( payload, ) + def load_lora_adapter(self, lora_name: str, lora_path: str): + return self._make_request( + "load_lora_adapter", + {"lora_name": lora_name, 
"lora_path": lora_path}, + ) + + def unload_lora_adapter(self, lora_name: str): + return self._make_request( + "unload_lora_adapter", + {"lora_name": lora_name}, + ) + def pause_generation(self): response = requests.post(f"http://{self.server_host}:{self.server_port}/pause_generation", json={}) response.raise_for_status() @@ -419,6 +437,10 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port, work kwargs["enable_return_routed_experts"] = True if args.fp16: kwargs["dtype"] = "float16" + if args.lora_rank > 0 or args.lora_adapter_path is not None: + kwargs["enable_lora"] = True + kwargs["max_lora_rank"] = args.lora_rank + kwargs["lora_target_modules"] = args.target_modules external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS] diff --git a/miles/ray/placement_group.py b/miles/ray/placement_group.py index b6fb7a20b..7bd842960 100644 --- a/miles/ray/placement_group.py +++ b/miles/ray/placement_group.py @@ -177,6 +177,7 @@ def create_rollout_manager(args, pg): ray.get(rollout_manager.check_weights.remote(action="reset_tensors")) if args.offload_rollout: + # TODO: Optimization in the future: offload model weights to cpu to make more space for training? 
ray.get(rollout_manager.offload.remote()) return rollout_manager, num_rollout_per_epoch diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 9ee0fbb8a..83c3a5519 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -127,8 +127,8 @@ def save(self, rollout_id): def load(self, rollout_id=None): self.data_source.load(rollout_id) - def offload(self): - return ray.get([engine.release_memory_occupation.remote() for engine in self.rollout_engines]) + def offload(self, tags: list[str] = None): + return ray.get([engine.release_memory_occupation.remote(tags=tags) for engine in self.rollout_engines]) def onload(self, tags: list[str] = None): return ray.get([engine.resume_memory_occupation.remote(tags=tags) for engine in self.rollout_engines]) diff --git a/miles/rollout/sglang_rollout.py b/miles/rollout/sglang_rollout.py index 2e33542a5..36f6e7ce0 100644 --- a/miles/rollout/sglang_rollout.py +++ b/miles/rollout/sglang_rollout.py @@ -11,6 +11,7 @@ from packaging.version import parse from tqdm import tqdm +from miles.backends.fsdp_utils.lora_utils import LORA_ADAPTER_NAME from miles.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput from miles.rollout.filter_hub.base_types import DynamicFilterOutput from miles.utils.async_utils import run @@ -124,6 +125,10 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A "return_logprob": True, } + # Use LoRA adapter when LoRA is enabled + if args.lora_rank > 0 or args.lora_adapter_path is not None: + payload["lora_path"] = LORA_ADAPTER_NAME + if args.use_rollout_routing_replay: payload["return_routed_experts"] = True diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index ce6e47161..b1c425550 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -104,6 +104,15 @@ def add_cluster_arguments(parser): "This will always be true when --colocate is set." 
), ) + parser.add_argument( + "--offload-rollout-level", + type=int, + default=2, + help=( + "The offload level for rollout when offload-rollout is set. " + "1 means only offload kv cache, 2 means offload kv cache and weights." + ), + ) reset_arg(parser, "--distributed-backend", type=str, default="nccl") reset_arg(parser, "--distributed-timeout-minutes", type=int, default=10) @@ -1415,6 +1424,27 @@ def miles_validate_args(args): if args.save_interval is not None: assert args.save is not None, "'--save' is required when save_interval is set." + if args.lora_rank > 0: + # assert args.save is not None, "'--save' is required when LoRA is enabled." + assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." + + if args.target_modules == "all-linear": + modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + elif "," in args.target_modules: + modules = [m.strip() for m in args.target_modules.split(",")] + else: + modules = [args.target_modules] + + if args.exclude_modules: + exclude_set = ( + set(m.strip() for m in args.exclude_modules.split(",")) + if "," in args.exclude_modules + else {args.exclude_modules} + ) + modules = [m for m in modules if m not in exclude_set] + + args.target_modules = modules + assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set" if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]: diff --git a/requirements.txt b/requirements.txt index 2c20195fc..3840f294d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ httpx[http2] mcp[cli] memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml omegaconf +peft pillow pylatexenc pyyaml diff --git a/train.py b/train.py index 9fb480eda..a43aa61db 100644 --- a/train.py +++ b/train.py @@ -56,7 +56,7 @@ def offload_train(): actor_model.clear_memory() def onload_rollout(): - if 
args.offload_rollout: + if args.offload_rollout and args.offload_rollout_level == 2: ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])) # train loop. @@ -68,7 +68,16 @@ def onload_rollout(): rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id)) if args.offload_rollout: - ray.get(rollout_manager.offload.remote()) + # level 1: offload kv cache only, level 2: offload weights + kv cache + ray.get( + rollout_manager.offload.remote( + tags=( + [GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_CUDA_GRAPH] + if args.offload_rollout_level == 1 + else [GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_CUDA_GRAPH] + ) + ) + ) if args.use_critic: critic_train_handle = critic_model.async_train(rollout_id, rollout_data_ref) From 54d57175a27d60e29f20813a21aa127859a3acfd Mon Sep 17 00:00:00 2001 From: "Ethan (Yusheng) Su" Date: Thu, 25 Dec 2025 00:06:05 -0800 Subject: [PATCH 12/21] Revert "Add LoRA for FSDP backend. (#307)" [will merge on later] (#351) --- miles/backends/fsdp_utils/actor.py | 11 +- miles/backends/fsdp_utils/arguments.py | 7 -- miles/backends/fsdp_utils/checkpoint.py | 47 ++------ miles/backends/fsdp_utils/lora_utils.py | 77 ------------ .../fsdp_utils/update_weight_utils.py | 112 +++++------------- miles/backends/sglang_utils/sglang_engine.py | 26 +--- miles/ray/placement_group.py | 1 - miles/ray/rollout.py | 4 +- miles/rollout/sglang_rollout.py | 5 - miles/utils/arguments.py | 30 ----- requirements.txt | 1 - train.py | 13 +- 12 files changed, 50 insertions(+), 284 deletions(-) delete mode 100644 miles/backends/fsdp_utils/lora_utils.py diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py index 2cfbc4339..1e3e5b3ae 100644 --- a/miles/backends/fsdp_utils/actor.py +++ b/miles/backends/fsdp_utils/actor.py @@ -28,7 +28,6 @@ from ...utils.profile_utils import TrainProfiler from . 
import checkpoint from .data_packing import pack_sequences, pad_packed_sequence_with_cp, unpack_sequences -from .lora_utils import apply_lora_to_model, is_lora_model from .lr_scheduler import get_lr_scheduler from .update_weight_utils import UpdateWeightFromDistributed, UpdateWeightFromTensor @@ -95,9 +94,6 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty attn_implementation=self.args.attn_implementation, ) - if self.args.lora_rank > 0 or self.args.lora_adapter_path: - model = apply_lora_to_model(model, self.args) - model.train() full_state = model.state_dict() @@ -111,14 +107,11 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty self.model = model if args.gradient_checkpointing: - # Avoid "does not require grad" error - gc_kwargs = {"use_reentrant": False} if is_lora_model(self.model) else {} - self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs) + self.model.gradient_checkpointing_enable() if args.optimizer == "adam": - trainable_params = [p for p in self.model.parameters() if p.requires_grad] self.optimizer = torch.optim.AdamW( - trainable_params, + self.model.parameters(), lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps, diff --git a/miles/backends/fsdp_utils/arguments.py b/miles/backends/fsdp_utils/arguments.py index 7cd10867b..a319fe6e5 100644 --- a/miles/backends/fsdp_utils/arguments.py +++ b/miles/backends/fsdp_utils/arguments.py @@ -60,13 +60,6 @@ class FSDPArgs: # YAML bookkeeping config: str | None = None - # LoRA configuration - lora_rank: int = 0 - lora_alpha: int = 16 - target_modules: str = "all-linear" - exclude_modules: str | None = None - lora_adapter_path: str | None = None - def parse_fsdp_cli(extra_args_provider=None): parser = argparse.ArgumentParser("FSDP SFT Training (miles)") diff --git a/miles/backends/fsdp_utils/checkpoint.py b/miles/backends/fsdp_utils/checkpoint.py index 8508fba2b..3c49a10f8 100644 --- 
a/miles/backends/fsdp_utils/checkpoint.py +++ b/miles/backends/fsdp_utils/checkpoint.py @@ -12,34 +12,21 @@ from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict from torch.distributed.checkpoint.stateful import Stateful -from miles.backends.fsdp_utils.lora_utils import is_lora_model - logger = logging.getLogger(__name__) class ModelState(Stateful): """Wrapper for model state only.""" - def __init__(self, model, lora_only: bool = False): + def __init__(self, model): self.model = model - self.lora_only = lora_only - self._key = "adapter" if lora_only else "model" def state_dict(self): model_state_dict, _ = get_state_dict(self.model, optimizers=[]) - if self.lora_only: - model_state_dict = {k: v for k, v in model_state_dict.items() if "lora_" in k} - return {self._key: model_state_dict} + return {"model": model_state_dict} def load_state_dict(self, state_dict): - data = state_dict[self._key] - - if self.lora_only: - full_state_dict, _ = get_state_dict(self.model, optimizers=[]) - full_state_dict.update(data) - set_state_dict(self.model, optimizers=[], model_state_dict=full_state_dict, optim_state_dict=None) - else: - set_state_dict(self.model, optimizers=[], model_state_dict=data, optim_state_dict=None) + set_state_dict(self.model, optimizers=[], model_state_dict=state_dict["model"], optim_state_dict=None) class OptimizerState(Stateful): @@ -116,22 +103,20 @@ def load(actor: Any) -> dict[str, Any] | None: model_dir = checkpoint_dir / "model" optimizer_dir = checkpoint_dir / "optimizer" lr_scheduler_dir = checkpoint_dir / "lr_scheduler" - lora_dir = checkpoint_dir / "adapter" - - lora_only = lora_dir.exists() and is_lora_model(actor.model) - model_dir = lora_dir if lora_only else model_dir if not model_dir.exists(): - logger.info(f"[FSDP] No model checkpoint found at {model_dir} or {lora_dir}; skipping load.") + logger.info(f"[FSDP] Model checkpoint {model_dir} not found; skipping load.") return None - model_state = ModelState(actor.model, 
lora_only=lora_only) + # Load model weights (always) + model_state = ModelState(actor.model) state_dict = {"model_state": model_state} + try: dcp.load(state_dict=state_dict, checkpoint_id=str(model_dir)) - logger.info(f"[FSDP] Loaded {'LoRA adapter' if lora_only else 'model'} from {model_dir}") + logger.info(f"[FSDP] Loaded model from {model_dir}") except Exception as e: - logger.error(f"[FSDP] Failed to load {'LoRA adapter' if lora_only else 'model'} from {model_dir}: {e}") + logger.error(f"[FSDP] Failed to load model from {model_dir}: {e}") return None # Load optimizer state (optional) @@ -225,19 +210,9 @@ def save(actor: Any, iteration: int) -> None: dist.barrier() # Save model weights - lora_only = is_lora_model(actor.model) - if lora_only: - save_dir = checkpoint_dir / "adapter" - if dist.get_rank() == 0: - save_dir.mkdir(parents=True, exist_ok=True) - dist.barrier() - else: - save_dir = model_dir - - model_state = ModelState(actor.model, lora_only=lora_only) + model_state = ModelState(actor.model) state_dict = {"model_state": model_state} - dcp.save(state_dict, checkpoint_id=str(save_dir)) - logger.info(f"[FSDP] Saved {'LoRA adapter' if lora_only else 'model'} to {save_dir}") + dcp.save(state_dict, checkpoint_id=str(model_dir)) # Save optimizer state if hasattr(actor, "optimizer") and actor.optimizer is not None: diff --git a/miles/backends/fsdp_utils/lora_utils.py b/miles/backends/fsdp_utils/lora_utils.py deleted file mode 100644 index d6483b372..000000000 --- a/miles/backends/fsdp_utils/lora_utils.py +++ /dev/null @@ -1,77 +0,0 @@ -import logging -import os -import shutil -from pathlib import Path - -import torch.distributed as dist -import torch.nn as nn -from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict - -try: - from peft import LoraConfig, PeftModel, TaskType, get_peft_model -except ImportError as err: - raise ImportError("peft library required for LoRA. 
Install with: pip install peft") from err - -logger = logging.getLogger(__name__) - -LORA_READY_MARKER = ".lora_ready" -LORA_ADAPTER_NAME = "miles_lora" -LORA_SUBDIR = "tmp_lora" - - -def apply_lora_to_model(model: nn.Module, args) -> nn.Module: - if args.lora_adapter_path: - logger.info(f"Loading LoRA adapter from {args.lora_adapter_path}") - model = PeftModel.from_pretrained(model, args.lora_adapter_path, is_trainable=True) - peft_config = model.peft_config["default"] - if isinstance(peft_config.task_type, str): - peft_config.task_type = TaskType.CAUSAL_LM - model.print_trainable_parameters() - return model - - lora_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, - r=args.lora_rank, - lora_alpha=args.lora_alpha, - target_modules=args.target_modules, - bias="none", - ) - - model = get_peft_model(model, lora_config) # autocast_adapter_dtype=False) - model.print_trainable_parameters() - logger.info(f"Applied LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}") - return model - - -def is_lora_model(module: nn.Module) -> bool: - unwrapped = getattr(module, "_fsdp_wrapped_module", module) - return hasattr(unwrapped, "peft_config") - - -def save_lora_to_disk(module: nn.Module, save_dir: str) -> str: - """Save LoRA adapter to disk with file lock mechanism.""" - # TODO: All gather lora layers not full layers - options = StateDictOptions(full_state_dict=True, cpu_offload=True) - full_state_dict = get_model_state_dict(module, options=options) - - lora_state_dict = {name: param for name, param in full_state_dict.items() if "lora_" in name} - - if dist.get_rank() == 0: - save_path = Path(save_dir) - save_path.mkdir(parents=True, exist_ok=True) - - module.save_pretrained(str(save_path), state_dict=lora_state_dict) - - # TODO: check if file lock is needed or better way to do it - os.sync() - - logger.info(f"Saved LoRA adapter to {save_path}") - return save_dir - - -def delete_lora_from_disk(save_dir: str) -> None: - """Delete LoRA adapter files from disk.""" - save_path 
= Path(save_dir) - if save_path.exists(): - shutil.rmtree(save_path) - logger.info(f"Deleted LoRA adapter from {save_path}") diff --git a/miles/backends/fsdp_utils/update_weight_utils.py b/miles/backends/fsdp_utils/update_weight_utils.py index 6e4ee73a5..c8dcbd810 100644 --- a/miles/backends/fsdp_utils/update_weight_utils.py +++ b/miles/backends/fsdp_utils/update_weight_utils.py @@ -1,6 +1,5 @@ import abc import logging -import os import socket from argparse import Namespace from collections.abc import Sequence @@ -26,7 +25,6 @@ except ImportError: from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] -from .lora_utils import LORA_ADAPTER_NAME, LORA_SUBDIR, delete_lora_from_disk, is_lora_model, save_lora_to_disk logger = logging.getLogger(__name__) @@ -35,9 +33,6 @@ class UpdateWeight(abc.ABC): def __init__(self, args: Namespace, model: torch.nn.Module) -> None: self.args = args self.model = model - self.weight_version = 0 - self._lora_loaded = False - self._base_synced = False @abc.abstractmethod def connect_rollout_engines( @@ -48,85 +43,38 @@ def connect_rollout_engines( pass def update_weights(self) -> None: - self.weight_version += 1 - - # Update base model if needed - # Level 1: only sync base once for LoRA models, then just LoRA - # Level 2: always sync base + LoRA - if not (is_lora_model(self.model) and self._base_synced and self.args.offload_rollout_level == 1): - bucket = [] - bucket_size = 0 - for name, param in self.model.state_dict().items(): - if any(x in name for x in ["_flat_param", "lora_"]): - continue - name = name.replace("base_model.model.", "").replace(".base_layer", "") - param_size = param.numel() * param.element_size() - if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size: - self.wait_and_update_bucket_weights(bucket) - del bucket - bucket = [] - bucket_size = 0 - - param = param.cuda() - if isinstance(param, DTensor): - # async version of param.full_tensor - param = 
param.redistribute( - placements=[Replicate()] * param.device_mesh.ndim, - async_op=True, - ).to_local() - bucket.append((name, param)) - bucket_size += param_size - - if bucket: + bucket = [] + bucket_size = 0 + for name, param in self.model.state_dict().items(): + param_size = param.numel() * param.element_size() + if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size: self.wait_and_update_bucket_weights(bucket) del bucket - - self._base_synced = True - - # Update lora weights if needed - if is_lora_model(self.model): - self._update_lora_via_file() - - def _update_lora_via_file(self) -> None: - """Push LoRA weights to rollout engines using disk files.""" - self._lora_save_dir = os.path.join(self.args.save, LORA_SUBDIR) - if dist.get_rank() == 0: - if os.path.exists(self._lora_save_dir): - delete_lora_from_disk(self._lora_save_dir) - - dist.barrier() - - save_lora_to_disk(self.model, self._lora_save_dir) - - dist.barrier() - - if dist.get_rank() == 0: - if self._lora_loaded: - refs = [engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for engine in self.rollout_engines] - ray.get(refs) - - refs = [engine.flush_cache.remote() for engine in self.rollout_engines] - ray.get(refs) - - refs = [ - engine.load_lora_adapter.remote(LORA_ADAPTER_NAME, self._lora_save_dir) - for engine in self.rollout_engines - ] - ray.get(refs) - - refs = [engine.flush_cache.remote() for engine in self.rollout_engines] - ray.get(refs) - - self._lora_loaded = True - - dist.barrier() + bucket = [] + bucket_size = 0 + + param = param.cuda() + if isinstance(param, DTensor): + # async version of param.full_tensor + param = param.redistribute( + placements=[Replicate()] * param.device_mesh.ndim, + async_op=True, + ).to_local() + bucket.append((name, param)) + bucket_size += param_size + + if bucket: + self.wait_and_update_bucket_weights(bucket) + del bucket + bucket = [] + bucket_size = 0 def wait_and_update_bucket_weights(self, bucket): bucket = [(name, param.wait()) if 
hasattr(param, "wait") else (name, param) for name, param in bucket] - self.update_bucket_weights(bucket, weight_version=self.weight_version) + self.update_bucket_weights(bucket) @abc.abstractmethod - def update_bucket_weights(self, named_tensors, weight_version=None) -> None: + def update_bucket_weights(self, named_tensors) -> None: pass @@ -166,7 +114,7 @@ def connect_rollout_engines( # Calculate TP rank within this SGLang engine group self.tp_rank = dist.get_rank() - start_rank - def update_bucket_weights(self, named_tensors, weight_version=None) -> None: + def update_bucket_weights(self, named_tensors) -> None: monkey_patch_torch_reductions() # Use flattened bucket approach similar to Megatron logger.info("Using flattened tensor bucket") @@ -214,7 +162,6 @@ def update_bucket_weights(self, named_tensors, weight_version=None) -> None: "serialized_named_tensors": [tensors[i] for tensors in gathered_serialized_batches], "load_format": "flattened_bucket", "flush_cache": False, - "weight_version": str(weight_version), } ref = self._ipc_engine.update_weights_from_tensor.remote(**kwargs) ray.get(ref) @@ -227,6 +174,10 @@ def update_bucket_weights(self, named_tensors, weight_version=None) -> None: class UpdateWeightFromDistributed(UpdateWeight): """Broadcast weights via a temporary NCCL group to rollout engines.""" + def __init__(self, args: Namespace, model: torch.nn.Module) -> None: + self.args = args + self.model = model + def connect_rollout_engines( self, rollout_engines: Sequence[ActorHandle], @@ -269,7 +220,7 @@ def connect_rollout_engines( ) ray.get(refs) - def update_bucket_weights(self, named_tensors, weight_version=None) -> None: + def update_bucket_weights(self, named_tensors) -> None: """Send names/dtypes/shapes metadata to engines, then broadcast tensors. 
Ensures tensors are contiguous; when `world_size == 1`, converts DTensors @@ -284,7 +235,6 @@ def update_bucket_weights(self, named_tensors, weight_version=None) -> None: dtypes=[param.dtype for _, param in named_tensors], shapes=[param.shape for _, param in named_tensors], group_name=self._group_name, - weight_version=str(weight_version), ) for engine in self.rollout_engines ] diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index c9a774b32..2e1afe625 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -278,15 +278,9 @@ def get_weight_version(self): response.raise_for_status() return response.json()["weight_version"] - def release_memory_occupation(self, tags: list[str] = None): - """ - Available tags for multi-stage resume: weights, kv_cache - """ + def release_memory_occupation(self): self.flush_cache() - return self._make_request( - "release_memory_occupation", - {"tags": tags}, - ) + return self._make_request("release_memory_occupation") def resume_memory_occupation(self, tags: list[str] = None): """ @@ -342,18 +336,6 @@ def update_weights_from_distributed( payload, ) - def load_lora_adapter(self, lora_name: str, lora_path: str): - return self._make_request( - "load_lora_adapter", - {"lora_name": lora_name, "lora_path": lora_path}, - ) - - def unload_lora_adapter(self, lora_name: str): - return self._make_request( - "unload_lora_adapter", - {"lora_name": lora_name}, - ) - def pause_generation(self): response = requests.post(f"http://{self.server_host}:{self.server_port}/pause_generation", json={}) response.raise_for_status() @@ -437,10 +419,6 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port, work kwargs["enable_return_routed_experts"] = True if args.fp16: kwargs["dtype"] = "float16" - if args.lora_rank > 0 or args.lora_adapter_path is not None: - kwargs["enable_lora"] = True - kwargs["max_lora_rank"] = args.lora_rank 
- kwargs["lora_target_modules"] = args.target_modules external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS] diff --git a/miles/ray/placement_group.py b/miles/ray/placement_group.py index 7bd842960..b6fb7a20b 100644 --- a/miles/ray/placement_group.py +++ b/miles/ray/placement_group.py @@ -177,7 +177,6 @@ def create_rollout_manager(args, pg): ray.get(rollout_manager.check_weights.remote(action="reset_tensors")) if args.offload_rollout: - # TODO: Optimization in the future: offload model weights to cpu to make more space for training? ray.get(rollout_manager.offload.remote()) return rollout_manager, num_rollout_per_epoch diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 83c3a5519..9ee0fbb8a 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -127,8 +127,8 @@ def save(self, rollout_id): def load(self, rollout_id=None): self.data_source.load(rollout_id) - def offload(self, tags: list[str] = None): - return ray.get([engine.release_memory_occupation.remote(tags=tags) for engine in self.rollout_engines]) + def offload(self): + return ray.get([engine.release_memory_occupation.remote() for engine in self.rollout_engines]) def onload(self, tags: list[str] = None): return ray.get([engine.resume_memory_occupation.remote(tags=tags) for engine in self.rollout_engines]) diff --git a/miles/rollout/sglang_rollout.py b/miles/rollout/sglang_rollout.py index 36f6e7ce0..2e33542a5 100644 --- a/miles/rollout/sglang_rollout.py +++ b/miles/rollout/sglang_rollout.py @@ -11,7 +11,6 @@ from packaging.version import parse from tqdm import tqdm -from miles.backends.fsdp_utils.lora_utils import LORA_ADAPTER_NAME from miles.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput from miles.rollout.filter_hub.base_types import DynamicFilterOutput from miles.utils.async_utils import run @@ -125,10 +124,6 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A "return_logprob": 
True, } - # Use LoRA adapter when LoRA is enabled - if args.lora_rank > 0 or args.lora_adapter_path is not None: - payload["lora_path"] = LORA_ADAPTER_NAME - if args.use_rollout_routing_replay: payload["return_routed_experts"] = True diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index b1c425550..ce6e47161 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -104,15 +104,6 @@ def add_cluster_arguments(parser): "This will always be true when --colocate is set." ), ) - parser.add_argument( - "--offload-rollout-level", - type=int, - default=2, - help=( - "The offload level for rollout when offload-rollout is set. " - "1 means only offload kv cache, 2 means offload kv cache and weights." - ), - ) reset_arg(parser, "--distributed-backend", type=str, default="nccl") reset_arg(parser, "--distributed-timeout-minutes", type=int, default=10) @@ -1424,27 +1415,6 @@ def miles_validate_args(args): if args.save_interval is not None: assert args.save is not None, "'--save' is required when save_interval is set." - if args.lora_rank > 0: - # assert args.save is not None, "'--save' is required when LoRA is enabled." - assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." 
- - if args.target_modules == "all-linear": - modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] - elif "," in args.target_modules: - modules = [m.strip() for m in args.target_modules.split(",")] - else: - modules = [args.target_modules] - - if args.exclude_modules: - exclude_set = ( - set(m.strip() for m in args.exclude_modules.split(",")) - if "," in args.exclude_modules - else {args.exclude_modules} - ) - modules = [m for m in modules if m not in exclude_set] - - args.target_modules = modules - assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set" if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]: diff --git a/requirements.txt b/requirements.txt index 3840f294d..2c20195fc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,6 @@ httpx[http2] mcp[cli] memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml omegaconf -peft pillow pylatexenc pyyaml diff --git a/train.py b/train.py index a43aa61db..9fb480eda 100644 --- a/train.py +++ b/train.py @@ -56,7 +56,7 @@ def offload_train(): actor_model.clear_memory() def onload_rollout(): - if args.offload_rollout and args.offload_rollout_level == 2: + if args.offload_rollout: ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])) # train loop. 
@@ -68,16 +68,7 @@ def onload_rollout(): rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id)) if args.offload_rollout: - # level 1: offload kv cache only, level 2: offload weights + kv cache - ray.get( - rollout_manager.offload.remote( - tags=( - [GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_CUDA_GRAPH] - if args.offload_rollout_level == 1 - else [GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_CUDA_GRAPH] - ) - ) - ) + ray.get(rollout_manager.offload.remote()) if args.use_critic: critic_train_handle = critic_model.async_train(rollout_id, rollout_data_ref) From 81a07116eefbc8fba450a57028536cdb953abef9 Mon Sep 17 00:00:00 2001 From: "lg(x)" <70553669+GuanxingLu@users.noreply.github.com> Date: Sun, 28 Dec 2025 10:19:20 -0800 Subject: [PATCH 13/21] [Fet] Lora FSDP RL training - #326 and add CI/CD tests (#351) (#352) Co-authored-by: PopSoda2002 --- .github/workflows/pr-test.yml | 6 +- .github/workflows/pr-test.yml.j2 | 4 + miles/backends/fsdp_utils/actor.py | 11 +- miles/backends/fsdp_utils/arguments.py | 7 ++ miles/backends/fsdp_utils/checkpoint.py | 47 ++++++-- miles/backends/fsdp_utils/lora_utils.py | 76 ++++++++++++ .../fsdp_utils/update_weight_utils.py | 110 +++++++++++++----- miles/backends/sglang_utils/sglang_engine.py | 26 ++++- miles/ray/placement_group.py | 1 + miles/ray/rollout.py | 6 +- miles/rollout/sglang_rollout.py | 5 + miles/utils/arguments.py | 32 +++++ requirements.txt | 1 + tests/test_external_rollout.py | 2 +- tests/test_qwen2.5_0.5B_gsm8k.py | 2 +- tests/test_qwen2.5_0.5B_gsm8k_async.py | 2 +- tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py | 25 +++- tests/test_qwen3_0.6B_fsdp_distributed.py | 20 +++- train.py | 9 +- 19 files changed, 330 insertions(+), 62 deletions(-) create mode 100644 miles/backends/fsdp_utils/lora_utils.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index e649da717..d662dbaf1 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -63,6 
+63,8 @@ jobs: - name: Execute shell: bash + env: + ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }} run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} e2e-test-long: @@ -84,7 +86,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}] + info: [{"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}, {"enable_lora": "1", "num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"enable_lora": "1", "num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -103,4 +105,6 @@ jobs: - name: Execute shell: bash + env: + ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }} run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 06d6ed570..56e24b23f 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -13,6 +13,8 @@ {'test_file': 'test_qwen2.5_0.5B_gsm8k_async.py', 'num_gpus': 2}, {'test_file': 'test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2}, {'test_file': 'test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2}, + {'test_file': 'test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2, 'enable_lora': '1'}, + {'test_file': 'test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2, 'enable_lora': '1'}, ], }, } %> @@ -77,5 +79,7 @@ jobs: - name: Execute shell: bash + env: + 
ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }} run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} <% endfor %> \ No newline at end of file diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py index 1e3e5b3ae..2cfbc4339 100644 --- a/miles/backends/fsdp_utils/actor.py +++ b/miles/backends/fsdp_utils/actor.py @@ -28,6 +28,7 @@ from ...utils.profile_utils import TrainProfiler from . import checkpoint from .data_packing import pack_sequences, pad_packed_sequence_with_cp, unpack_sequences +from .lora_utils import apply_lora_to_model, is_lora_model from .lr_scheduler import get_lr_scheduler from .update_weight_utils import UpdateWeightFromDistributed, UpdateWeightFromTensor @@ -94,6 +95,9 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty attn_implementation=self.args.attn_implementation, ) + if self.args.lora_rank > 0 or self.args.lora_adapter_path: + model = apply_lora_to_model(model, self.args) + model.train() full_state = model.state_dict() @@ -107,11 +111,14 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty self.model = model if args.gradient_checkpointing: - self.model.gradient_checkpointing_enable() + # Avoid "does not require grad" error + gc_kwargs = {"use_reentrant": False} if is_lora_model(self.model) else {} + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs) if args.optimizer == "adam": + trainable_params = [p for p in self.model.parameters() if p.requires_grad] self.optimizer = torch.optim.AdamW( - self.model.parameters(), + trainable_params, lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps, diff --git a/miles/backends/fsdp_utils/arguments.py b/miles/backends/fsdp_utils/arguments.py index a319fe6e5..7cd10867b 100644 --- a/miles/backends/fsdp_utils/arguments.py +++ b/miles/backends/fsdp_utils/arguments.py @@ -60,6 +60,13 @@ class FSDPArgs: 
# YAML bookkeeping config: str | None = None + # LoRA configuration + lora_rank: int = 0 + lora_alpha: int = 16 + target_modules: str = "all-linear" + exclude_modules: str | None = None + lora_adapter_path: str | None = None + def parse_fsdp_cli(extra_args_provider=None): parser = argparse.ArgumentParser("FSDP SFT Training (miles)") diff --git a/miles/backends/fsdp_utils/checkpoint.py b/miles/backends/fsdp_utils/checkpoint.py index 3c49a10f8..8508fba2b 100644 --- a/miles/backends/fsdp_utils/checkpoint.py +++ b/miles/backends/fsdp_utils/checkpoint.py @@ -12,21 +12,34 @@ from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict from torch.distributed.checkpoint.stateful import Stateful +from miles.backends.fsdp_utils.lora_utils import is_lora_model + logger = logging.getLogger(__name__) class ModelState(Stateful): """Wrapper for model state only.""" - def __init__(self, model): + def __init__(self, model, lora_only: bool = False): self.model = model + self.lora_only = lora_only + self._key = "adapter" if lora_only else "model" def state_dict(self): model_state_dict, _ = get_state_dict(self.model, optimizers=[]) - return {"model": model_state_dict} + if self.lora_only: + model_state_dict = {k: v for k, v in model_state_dict.items() if "lora_" in k} + return {self._key: model_state_dict} def load_state_dict(self, state_dict): - set_state_dict(self.model, optimizers=[], model_state_dict=state_dict["model"], optim_state_dict=None) + data = state_dict[self._key] + + if self.lora_only: + full_state_dict, _ = get_state_dict(self.model, optimizers=[]) + full_state_dict.update(data) + set_state_dict(self.model, optimizers=[], model_state_dict=full_state_dict, optim_state_dict=None) + else: + set_state_dict(self.model, optimizers=[], model_state_dict=data, optim_state_dict=None) class OptimizerState(Stateful): @@ -103,20 +116,22 @@ def load(actor: Any) -> dict[str, Any] | None: model_dir = checkpoint_dir / "model" optimizer_dir = checkpoint_dir / 
"optimizer" lr_scheduler_dir = checkpoint_dir / "lr_scheduler" + lora_dir = checkpoint_dir / "adapter" + + lora_only = lora_dir.exists() and is_lora_model(actor.model) + model_dir = lora_dir if lora_only else model_dir if not model_dir.exists(): - logger.info(f"[FSDP] Model checkpoint {model_dir} not found; skipping load.") + logger.info(f"[FSDP] No model checkpoint found at {model_dir} or {lora_dir}; skipping load.") return None - # Load model weights (always) - model_state = ModelState(actor.model) + model_state = ModelState(actor.model, lora_only=lora_only) state_dict = {"model_state": model_state} - try: dcp.load(state_dict=state_dict, checkpoint_id=str(model_dir)) - logger.info(f"[FSDP] Loaded model from {model_dir}") + logger.info(f"[FSDP] Loaded {'LoRA adapter' if lora_only else 'model'} from {model_dir}") except Exception as e: - logger.error(f"[FSDP] Failed to load model from {model_dir}: {e}") + logger.error(f"[FSDP] Failed to load {'LoRA adapter' if lora_only else 'model'} from {model_dir}: {e}") return None # Load optimizer state (optional) @@ -210,9 +225,19 @@ def save(actor: Any, iteration: int) -> None: dist.barrier() # Save model weights - model_state = ModelState(actor.model) + lora_only = is_lora_model(actor.model) + if lora_only: + save_dir = checkpoint_dir / "adapter" + if dist.get_rank() == 0: + save_dir.mkdir(parents=True, exist_ok=True) + dist.barrier() + else: + save_dir = model_dir + + model_state = ModelState(actor.model, lora_only=lora_only) state_dict = {"model_state": model_state} - dcp.save(state_dict, checkpoint_id=str(model_dir)) + dcp.save(state_dict, checkpoint_id=str(save_dir)) + logger.info(f"[FSDP] Saved {'LoRA adapter' if lora_only else 'model'} to {save_dir}") # Save optimizer state if hasattr(actor, "optimizer") and actor.optimizer is not None: diff --git a/miles/backends/fsdp_utils/lora_utils.py b/miles/backends/fsdp_utils/lora_utils.py new file mode 100644 index 000000000..f7f85d84b --- /dev/null +++ 
b/miles/backends/fsdp_utils/lora_utils.py @@ -0,0 +1,76 @@ +import logging +import os +import shutil +from pathlib import Path + +import torch.distributed as dist +import torch.nn as nn +from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict + +try: + from peft import LoraConfig, PeftModel, TaskType, get_peft_model +except ImportError as err: + raise ImportError("peft library required for LoRA. Install with: pip install peft") from err + +logger = logging.getLogger(__name__) + +LORA_ADAPTER_NAME = "miles_lora" +LORA_SUBDIR = "tmp_lora" + + +def apply_lora_to_model(model: nn.Module, args) -> nn.Module: + if args.lora_adapter_path: + logger.info(f"Loading LoRA adapter from {args.lora_adapter_path}") + model = PeftModel.from_pretrained(model, args.lora_adapter_path, is_trainable=True) + peft_config = model.peft_config["default"] + if isinstance(peft_config.task_type, str): + peft_config.task_type = TaskType.CAUSAL_LM + model.print_trainable_parameters() + return model + + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + bias="none", + ) + + model = get_peft_model(model, lora_config) # autocast_adapter_dtype=False) + model.print_trainable_parameters() + logger.info(f"Applied LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}") + return model + + +def is_lora_model(module: nn.Module) -> bool: + unwrapped = getattr(module, "_fsdp_wrapped_module", module) + return hasattr(unwrapped, "peft_config") + + +def save_lora_to_disk(module: nn.Module, save_dir: str) -> str: + """Save LoRA adapter to disk with file lock mechanism.""" + # TODO: All gather lora layers not full layers + options = StateDictOptions(full_state_dict=True, cpu_offload=True) + full_state_dict = get_model_state_dict(module, options=options) + + lora_state_dict = {name: param for name, param in full_state_dict.items() if "lora_" in name} + + if dist.get_rank() == 0: + 
save_path = Path(save_dir) + save_path.mkdir(parents=True, exist_ok=True) + + module.save_pretrained(str(save_path), state_dict=lora_state_dict) + + # TODO: check if file lock is needed or better way to do it + os.sync() + + logger.info(f"Saved LoRA adapter to {save_path}") + return save_dir + + +def delete_lora_from_disk(save_dir: str) -> None: + """Delete LoRA adapter files from disk.""" + save_path = Path(save_dir) + if save_path.exists(): + shutil.rmtree(save_path) + logger.info(f"Deleted LoRA adapter from {save_path}") diff --git a/miles/backends/fsdp_utils/update_weight_utils.py b/miles/backends/fsdp_utils/update_weight_utils.py index c8dcbd810..f48cca7f5 100644 --- a/miles/backends/fsdp_utils/update_weight_utils.py +++ b/miles/backends/fsdp_utils/update_weight_utils.py @@ -1,5 +1,6 @@ import abc import logging +import os import socket from argparse import Namespace from collections.abc import Sequence @@ -25,6 +26,7 @@ except ImportError: from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] +from .lora_utils import LORA_ADAPTER_NAME, LORA_SUBDIR, delete_lora_from_disk, is_lora_model, save_lora_to_disk logger = logging.getLogger(__name__) @@ -33,6 +35,9 @@ class UpdateWeight(abc.ABC): def __init__(self, args: Namespace, model: torch.nn.Module) -> None: self.args = args self.model = model + self.weight_version = 0 + self._lora_loaded = False + self._base_synced = False @abc.abstractmethod def connect_rollout_engines( @@ -43,38 +48,83 @@ def connect_rollout_engines( pass def update_weights(self) -> None: - bucket = [] - bucket_size = 0 - for name, param in self.model.state_dict().items(): - param_size = param.numel() * param.element_size() - if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size: - self.wait_and_update_bucket_weights(bucket) - del bucket - bucket = [] - bucket_size = 0 - - param = param.cuda() - if isinstance(param, DTensor): - # async version of param.full_tensor - param = 
param.redistribute( - placements=[Replicate()] * param.device_mesh.ndim, - async_op=True, - ).to_local() - bucket.append((name, param)) - bucket_size += param_size - - if bucket: - self.wait_and_update_bucket_weights(bucket) - del bucket + self.weight_version += 1 + + # Update base model if needed + if not (is_lora_model(self.model) and self._base_synced and "weight" not in self.args.offload_rollout_level): bucket = [] bucket_size = 0 + for name, param in self.model.state_dict().items(): + if any(x in name for x in ["_flat_param", "lora_"]): + continue + name = name.replace("base_model.model.", "").replace(".base_layer", "") + param_size = param.numel() * param.element_size() + if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size: + self.wait_and_update_bucket_weights(bucket) + del bucket + bucket = [] + bucket_size = 0 + + param = param.cuda() + if isinstance(param, DTensor): + # async version of param.full_tensor + param = param.redistribute( + placements=[Replicate()] * param.device_mesh.ndim, + async_op=True, + ).to_local() + bucket.append((name, param)) + bucket_size += param_size + + if bucket: + self.wait_and_update_bucket_weights(bucket) + del bucket + + self._base_synced = True + + # Update lora weights if needed + if is_lora_model(self.model): + self._update_lora_via_file() + + def _update_lora_via_file(self) -> None: + """Push LoRA weights to rollout engines using disk files.""" + self._lora_save_dir = os.path.join(self.args.save, LORA_SUBDIR) + if dist.get_rank() == 0: + if os.path.exists(self._lora_save_dir): + delete_lora_from_disk(self._lora_save_dir) + + dist.barrier() + + save_lora_to_disk(self.model, self._lora_save_dir) + + dist.barrier() + + if dist.get_rank() == 0: + if self._lora_loaded: + refs = [engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for engine in self.rollout_engines] + ray.get(refs) + + refs = [engine.flush_cache.remote() for engine in self.rollout_engines] + ray.get(refs) + + refs = [ + 
engine.load_lora_adapter.remote(LORA_ADAPTER_NAME, self._lora_save_dir) + for engine in self.rollout_engines + ] + ray.get(refs) + + refs = [engine.flush_cache.remote() for engine in self.rollout_engines] + ray.get(refs) + + self._lora_loaded = True + + dist.barrier() def wait_and_update_bucket_weights(self, bucket): bucket = [(name, param.wait()) if hasattr(param, "wait") else (name, param) for name, param in bucket] - self.update_bucket_weights(bucket) + self.update_bucket_weights(bucket, weight_version=self.weight_version) @abc.abstractmethod - def update_bucket_weights(self, named_tensors) -> None: + def update_bucket_weights(self, named_tensors, weight_version=None) -> None: pass @@ -114,7 +164,7 @@ def connect_rollout_engines( # Calculate TP rank within this SGLang engine group self.tp_rank = dist.get_rank() - start_rank - def update_bucket_weights(self, named_tensors) -> None: + def update_bucket_weights(self, named_tensors, weight_version=None) -> None: monkey_patch_torch_reductions() # Use flattened bucket approach similar to Megatron logger.info("Using flattened tensor bucket") @@ -162,6 +212,7 @@ def update_bucket_weights(self, named_tensors) -> None: "serialized_named_tensors": [tensors[i] for tensors in gathered_serialized_batches], "load_format": "flattened_bucket", "flush_cache": False, + "weight_version": str(weight_version), } ref = self._ipc_engine.update_weights_from_tensor.remote(**kwargs) ray.get(ref) @@ -174,10 +225,6 @@ def update_bucket_weights(self, named_tensors) -> None: class UpdateWeightFromDistributed(UpdateWeight): """Broadcast weights via a temporary NCCL group to rollout engines.""" - def __init__(self, args: Namespace, model: torch.nn.Module) -> None: - self.args = args - self.model = model - def connect_rollout_engines( self, rollout_engines: Sequence[ActorHandle], @@ -220,7 +267,7 @@ def connect_rollout_engines( ) ray.get(refs) - def update_bucket_weights(self, named_tensors) -> None: + def update_bucket_weights(self, 
named_tensors, weight_version=None) -> None: """Send names/dtypes/shapes metadata to engines, then broadcast tensors. Ensures tensors are contiguous; when `world_size == 1`, converts DTensors @@ -235,6 +282,7 @@ def update_bucket_weights(self, named_tensors) -> None: dtypes=[param.dtype for _, param in named_tensors], shapes=[param.shape for _, param in named_tensors], group_name=self._group_name, + weight_version=str(weight_version), ) for engine in self.rollout_engines ] diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index 2e1afe625..c9a774b32 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -278,9 +278,15 @@ def get_weight_version(self): response.raise_for_status() return response.json()["weight_version"] - def release_memory_occupation(self): + def release_memory_occupation(self, tags: list[str] = None): + """ + Available tags for multi-stage resume: weights, kv_cache + """ self.flush_cache() - return self._make_request("release_memory_occupation") + return self._make_request( + "release_memory_occupation", + {"tags": tags}, + ) def resume_memory_occupation(self, tags: list[str] = None): """ @@ -336,6 +342,18 @@ def update_weights_from_distributed( payload, ) + def load_lora_adapter(self, lora_name: str, lora_path: str): + return self._make_request( + "load_lora_adapter", + {"lora_name": lora_name, "lora_path": lora_path}, + ) + + def unload_lora_adapter(self, lora_name: str): + return self._make_request( + "unload_lora_adapter", + {"lora_name": lora_name}, + ) + def pause_generation(self): response = requests.post(f"http://{self.server_host}:{self.server_port}/pause_generation", json={}) response.raise_for_status() @@ -419,6 +437,10 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port, work kwargs["enable_return_routed_experts"] = True if args.fp16: kwargs["dtype"] = "float16" + if args.lora_rank > 0 or 
args.lora_adapter_path is not None: + kwargs["enable_lora"] = True + kwargs["max_lora_rank"] = args.lora_rank + kwargs["lora_target_modules"] = args.target_modules external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS] diff --git a/miles/ray/placement_group.py b/miles/ray/placement_group.py index b6fb7a20b..7bd842960 100644 --- a/miles/ray/placement_group.py +++ b/miles/ray/placement_group.py @@ -177,6 +177,7 @@ def create_rollout_manager(args, pg): ray.get(rollout_manager.check_weights.remote(action="reset_tensors")) if args.offload_rollout: + # TODO: Optimization in the future: offload model weights to cpu to make more space for training? ray.get(rollout_manager.offload.remote()) return rollout_manager, num_rollout_per_epoch diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 9ee0fbb8a..cb82c8ce8 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -127,8 +127,8 @@ def save(self, rollout_id): def load(self, rollout_id=None): self.data_source.load(rollout_id) - def offload(self): - return ray.get([engine.release_memory_occupation.remote() for engine in self.rollout_engines]) + def offload(self, tags: list[str] = None): + return ray.get([engine.release_memory_occupation.remote(tags=tags) for engine in self.rollout_engines]) def onload(self, tags: list[str] = None): return ray.get([engine.resume_memory_occupation.remote(tags=tags) for engine in self.rollout_engines]) @@ -412,7 +412,7 @@ def init_rollout_engines(args, pg, all_rollout_engines): num_new_engines = len(rollout_engines) if num_new_engines == 0: - return num_new_engines, None + return num_new_engines if args.rollout_external: addr_and_ports = _allocate_rollout_engine_addr_and_ports_external(args=args, rollout_engines=rollout_engines) diff --git a/miles/rollout/sglang_rollout.py b/miles/rollout/sglang_rollout.py index 2e33542a5..36f6e7ce0 100644 --- a/miles/rollout/sglang_rollout.py +++ b/miles/rollout/sglang_rollout.py @@ -11,6 
+11,7 @@ from packaging.version import parse from tqdm import tqdm +from miles.backends.fsdp_utils.lora_utils import LORA_ADAPTER_NAME from miles.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput from miles.rollout.filter_hub.base_types import DynamicFilterOutput from miles.utils.async_utils import run @@ -124,6 +125,10 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A "return_logprob": True, } + # Use LoRA adapter when LoRA is enabled + if args.lora_rank > 0 or args.lora_adapter_path is not None: + payload["lora_path"] = LORA_ADAPTER_NAME + if args.use_rollout_routing_replay: payload["return_routed_experts"] = True diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index ce6e47161..356b91a0e 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -104,6 +104,17 @@ def add_cluster_arguments(parser): "This will always be true when --colocate is set." ), ) + parser.add_argument( + "--offload-rollout-level", + type=str, + nargs="+", + default=["kv_cache", "weight"], + help=( + "Specifies what to offload during rollout when offload-rollout is set. " + "Possible values: 'kv_cache', 'weight'. Default: both 'kv_cache' and 'weight'. " + "Example: --offload-rollout-level kv_cache weight" + ), + ) reset_arg(parser, "--distributed-backend", type=str, default="nccl") reset_arg(parser, "--distributed-timeout-minutes", type=int, default=10) @@ -1415,6 +1426,27 @@ def miles_validate_args(args): if args.save_interval is not None: assert args.save is not None, "'--save' is required when save_interval is set." + if args.lora_rank > 0: + assert args.save is not None, "'--save' is required when LoRA is enabled." + assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." 
+ + if args.target_modules == "all-linear": + modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] + elif "," in args.target_modules: + modules = [m.strip() for m in args.target_modules.split(",")] + else: + modules = [args.target_modules] + + if args.exclude_modules: + exclude_set = ( + set(m.strip() for m in args.exclude_modules.split(",")) + if "," in args.exclude_modules + else {args.exclude_modules} + ) + modules = [m for m in modules if m not in exclude_set] + + args.target_modules = modules + assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set" if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]: diff --git a/requirements.txt b/requirements.txt index 2c20195fc..3840f294d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ httpx[http2] mcp[cli] memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml omegaconf +peft pillow pylatexenc pyyaml diff --git a/tests/test_external_rollout.py b/tests/test_external_rollout.py index c5c0838c5..f12837d88 100644 --- a/tests/test_external_rollout.py +++ b/tests/test_external_rollout.py @@ -14,7 +14,7 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/Qwen2.5-0.5B-Instruct --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) diff --git a/tests/test_qwen2.5_0.5B_gsm8k.py b/tests/test_qwen2.5_0.5B_gsm8k.py index 6302aadb6..26d3e3197 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k.py +++ b/tests/test_qwen2.5_0.5B_gsm8k.py @@ -11,7 +11,7 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download 
Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async.py b/tests/test_qwen2.5_0.5B_gsm8k_async.py index 1c55ccb20..878d68b1c 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async.py @@ -10,7 +10,7 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") diff --git a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py index 6967f9145..460f8a5a5 100644 --- a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py +++ b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py @@ -1,17 +1,29 @@ import miles.utils.external_utils.command_utils as U MODEL_NAME = "Qwen3-0.6B" +ENABLE_LORA = U.get_bool_env_var("ENABLE_LORA", "0") def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") def execute(): ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + lora_args = ( + ( + "--lora-rank 32 " + "--lora-alpha 32 " + "--target-modules all-linear " + f"--save /root/models/{MODEL_NAME}-lora-ckpt " + ) + if ENABLE_LORA + else "" + ) + rollout_args = ( "--prompt-data /root/datasets/gsm8k/train.parquet " "--input-key messages " @@ -50,7 +62,7 @@ def execute(): optimizer_args = ( "--optimizer adam " - "--lr 1e-6 " + f"--lr {'2e-5' if ENABLE_LORA else '1e-6'} " "--lr-decay-style constant " "--weight-decay 0.1 " "--adam-beta1 0.9 " @@ -73,10 +85,17 @@ def 
execute(): "--ci-metric-checker-threshold 0.71 " # loose threshold at 60 step ) - misc_args = "--actor-num-nodes 1 " "--actor-num-gpus-per-node 2 " "--colocate " "--train-backend fsdp " + misc_args = ( + "--actor-num-nodes 1 " + "--actor-num-gpus-per-node 2 " + "--colocate " + "--offload-rollout-level kv_cache weight " + "--train-backend fsdp " + ) train_args = ( f"{ckpt_args} " + f"{lora_args} " f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " diff --git a/tests/test_qwen3_0.6B_fsdp_distributed.py b/tests/test_qwen3_0.6B_fsdp_distributed.py index b3eb416b3..c4592ffdf 100644 --- a/tests/test_qwen3_0.6B_fsdp_distributed.py +++ b/tests/test_qwen3_0.6B_fsdp_distributed.py @@ -1,20 +1,30 @@ import miles.utils.external_utils.command_utils as U MODEL_NAME = "Qwen3-0.6B" - - +ENABLE_LORA = U.get_bool_env_var("ENABLE_LORA", "0") FEW_GPU = U.get_bool_env_var("MILES_TEST_FEW_GPU", "1") def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") def execute(): ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " + lora_args = ( + ( + "--lora-rank 32 " + "--lora-alpha 32 " + "--target-modules all-linear " + f"--save /root/models/{MODEL_NAME}-lora-ckpt " + ) + if ENABLE_LORA + else "" + ) + rollout_args = ( "--prompt-data /root/datasets/gsm8k/train.parquet " "--input-key messages " @@ -54,7 +64,7 @@ def execute(): optimizer_args = ( "--optimizer adam " - "--lr 1e-6 " + f"--lr {'2e-5' if ENABLE_LORA else '1e-6'} " "--lr-decay-style constant " "--weight-decay 0.1 " "--adam-beta1 0.9 " @@ -67,6 +77,7 @@ def execute(): "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {1 if FEW_GPU else 2} " f"--rollout-num-gpus {1 if FEW_GPU else 2} " + "--offload-rollout-level kv_cache weight " "--train-backend fsdp " ) @@ -79,6 +90,7 @@ def 
execute(): train_args = ( f"{ckpt_args} " + f"{lora_args} " f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " diff --git a/train.py b/train.py index 9fb480eda..212361fc0 100644 --- a/train.py +++ b/train.py @@ -56,7 +56,7 @@ def offload_train(): actor_model.clear_memory() def onload_rollout(): - if args.offload_rollout: + if args.offload_rollout and "weight" in args.offload_rollout_level: ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])) # train loop. @@ -68,7 +68,12 @@ def onload_rollout(): rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id)) if args.offload_rollout: - ray.get(rollout_manager.offload.remote()) + offload_tags = [GPU_MEMORY_TYPE_CUDA_GRAPH] + if "kv_cache" in args.offload_rollout_level: + offload_tags.append(GPU_MEMORY_TYPE_KV_CACHE) + if "weight" in args.offload_rollout_level: + offload_tags.append(GPU_MEMORY_TYPE_WEIGHTS) + ray.get(rollout_manager.offload.remote(tags=offload_tags)) if args.use_critic: critic_train_handle = critic_model.async_train(rollout_id, rollout_data_ref) From d5e140d93a83ea9b8a396ded608cf7d094e1d511 Mon Sep 17 00:00:00 2001 From: Yuzhou Nie <62874089+rucnyz@users.noreply.github.com> Date: Sun, 28 Dec 2025 21:43:36 -0800 Subject: [PATCH 14/21] Fix lora_rank attribute check in arguments.py (#363) --- miles/utils/arguments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 356b91a0e..0e63f914e 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1426,7 +1426,7 @@ def miles_validate_args(args): if args.save_interval is not None: assert args.save is not None, "'--save' is required when save_interval is set." - if args.lora_rank > 0: + if getattr(args, "lora_rank", 0) > 0: assert args.save is not None, "'--save' is required when LoRA is enabled." assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." 
From 3da31ef394c05f57f05739439d19e147b5fe0b28 Mon Sep 17 00:00:00 2001 From: "Ethan (Yusheng) Su" Date: Sun, 28 Dec 2025 22:53:19 -0800 Subject: [PATCH 15/21] Revert "Fix lora_rank attribute check in arguments.py" (#369) --- miles/utils/arguments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 0e63f914e..356b91a0e 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1426,7 +1426,7 @@ def miles_validate_args(args): if args.save_interval is not None: assert args.save is not None, "'--save' is required when save_interval is set." - if getattr(args, "lora_rank", 0) > 0: + if args.lora_rank > 0: assert args.save is not None, "'--save' is required when LoRA is enabled." assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." From ef57bde3632a4b02f6ef170decab7c51d1abc1a5 Mon Sep 17 00:00:00 2001 From: "Ethan (Yusheng) Su" Date: Sun, 28 Dec 2025 22:54:25 -0800 Subject: [PATCH 16/21] Revert "[Feat] Lora FSDP RL training - #326 and add CI/CD tests (#351)" (#370) --- .github/workflows/pr-test.yml | 6 +- .github/workflows/pr-test.yml.j2 | 4 - miles/backends/fsdp_utils/actor.py | 11 +- miles/backends/fsdp_utils/arguments.py | 7 -- miles/backends/fsdp_utils/checkpoint.py | 47 ++------ miles/backends/fsdp_utils/lora_utils.py | 76 ------------ .../fsdp_utils/update_weight_utils.py | 110 +++++------------- miles/backends/sglang_utils/sglang_engine.py | 26 +---- miles/ray/placement_group.py | 1 - miles/ray/rollout.py | 6 +- miles/rollout/sglang_rollout.py | 5 - miles/utils/arguments.py | 32 ----- requirements.txt | 1 - tests/test_external_rollout.py | 2 +- tests/test_qwen2.5_0.5B_gsm8k.py | 2 +- tests/test_qwen2.5_0.5B_gsm8k_async.py | 2 +- tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py | 25 +--- tests/test_qwen3_0.6B_fsdp_distributed.py | 20 +--- train.py | 9 +- 19 files changed, 62 insertions(+), 330 deletions(-) delete mode 100644 
miles/backends/fsdp_utils/lora_utils.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index d662dbaf1..e649da717 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -63,8 +63,6 @@ jobs: - name: Execute shell: bash - env: - ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }} run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} e2e-test-long: @@ -86,7 +84,7 @@ jobs: strategy: fail-fast: false matrix: - info: [{"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}, {"enable_lora": "1", "num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"enable_lora": "1", "num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}] + info: [{"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k.py"}, {"num_gpus": 2, "test_file": "test_qwen2.5_0.5B_gsm8k_async.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_colocated_2xGPU.py"}, {"num_gpus": 2, "test_file": "test_qwen3_0.6B_fsdp_distributed.py"}] defaults: run: working-directory: ${{ github.workspace }} @@ -105,6 +103,4 @@ jobs: - name: Execute shell: bash - env: - ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }} run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} diff --git a/.github/workflows/pr-test.yml.j2 b/.github/workflows/pr-test.yml.j2 index 56e24b23f..06d6ed570 100644 --- a/.github/workflows/pr-test.yml.j2 +++ b/.github/workflows/pr-test.yml.j2 @@ -13,8 +13,6 @@ {'test_file': 'test_qwen2.5_0.5B_gsm8k_async.py', 'num_gpus': 2}, {'test_file': 'test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2}, {'test_file': 'test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2}, - {'test_file': 
'test_qwen3_0.6B_fsdp_colocated_2xGPU.py', 'num_gpus': 2, 'enable_lora': '1'}, - {'test_file': 'test_qwen3_0.6B_fsdp_distributed.py', 'num_gpus': 2, 'enable_lora': '1'}, ], }, } %> @@ -79,7 +77,5 @@ jobs: - name: Execute shell: bash - env: - ENABLE_LORA: ${{ matrix.info.enable_lora || '0' }} run: python tests/ci/gpu_lock_exec.py --count ${{ matrix.info.num_gpus }} -- python tests/${{ matrix.info.test_file }} <% endfor %> \ No newline at end of file diff --git a/miles/backends/fsdp_utils/actor.py b/miles/backends/fsdp_utils/actor.py index 2cfbc4339..1e3e5b3ae 100644 --- a/miles/backends/fsdp_utils/actor.py +++ b/miles/backends/fsdp_utils/actor.py @@ -28,7 +28,6 @@ from ...utils.profile_utils import TrainProfiler from . import checkpoint from .data_packing import pack_sequences, pad_packed_sequence_with_cp, unpack_sequences -from .lora_utils import apply_lora_to_model, is_lora_model from .lr_scheduler import get_lr_scheduler from .update_weight_utils import UpdateWeightFromDistributed, UpdateWeightFromTensor @@ -95,9 +94,6 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty attn_implementation=self.args.attn_implementation, ) - if self.args.lora_rank > 0 or self.args.lora_adapter_path: - model = apply_lora_to_model(model, self.args) - model.train() full_state = model.state_dict() @@ -111,14 +107,11 @@ def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # ty self.model = model if args.gradient_checkpointing: - # Avoid "does not require grad" error - gc_kwargs = {"use_reentrant": False} if is_lora_model(self.model) else {} - self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs) + self.model.gradient_checkpointing_enable() if args.optimizer == "adam": - trainable_params = [p for p in self.model.parameters() if p.requires_grad] self.optimizer = torch.optim.AdamW( - trainable_params, + self.model.parameters(), lr=args.lr, betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps, diff 
--git a/miles/backends/fsdp_utils/arguments.py b/miles/backends/fsdp_utils/arguments.py index 7cd10867b..a319fe6e5 100644 --- a/miles/backends/fsdp_utils/arguments.py +++ b/miles/backends/fsdp_utils/arguments.py @@ -60,13 +60,6 @@ class FSDPArgs: # YAML bookkeeping config: str | None = None - # LoRA configuration - lora_rank: int = 0 - lora_alpha: int = 16 - target_modules: str = "all-linear" - exclude_modules: str | None = None - lora_adapter_path: str | None = None - def parse_fsdp_cli(extra_args_provider=None): parser = argparse.ArgumentParser("FSDP SFT Training (miles)") diff --git a/miles/backends/fsdp_utils/checkpoint.py b/miles/backends/fsdp_utils/checkpoint.py index 8508fba2b..3c49a10f8 100644 --- a/miles/backends/fsdp_utils/checkpoint.py +++ b/miles/backends/fsdp_utils/checkpoint.py @@ -12,34 +12,21 @@ from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict from torch.distributed.checkpoint.stateful import Stateful -from miles.backends.fsdp_utils.lora_utils import is_lora_model - logger = logging.getLogger(__name__) class ModelState(Stateful): """Wrapper for model state only.""" - def __init__(self, model, lora_only: bool = False): + def __init__(self, model): self.model = model - self.lora_only = lora_only - self._key = "adapter" if lora_only else "model" def state_dict(self): model_state_dict, _ = get_state_dict(self.model, optimizers=[]) - if self.lora_only: - model_state_dict = {k: v for k, v in model_state_dict.items() if "lora_" in k} - return {self._key: model_state_dict} + return {"model": model_state_dict} def load_state_dict(self, state_dict): - data = state_dict[self._key] - - if self.lora_only: - full_state_dict, _ = get_state_dict(self.model, optimizers=[]) - full_state_dict.update(data) - set_state_dict(self.model, optimizers=[], model_state_dict=full_state_dict, optim_state_dict=None) - else: - set_state_dict(self.model, optimizers=[], model_state_dict=data, optim_state_dict=None) + set_state_dict(self.model, 
optimizers=[], model_state_dict=state_dict["model"], optim_state_dict=None) class OptimizerState(Stateful): @@ -116,22 +103,20 @@ def load(actor: Any) -> dict[str, Any] | None: model_dir = checkpoint_dir / "model" optimizer_dir = checkpoint_dir / "optimizer" lr_scheduler_dir = checkpoint_dir / "lr_scheduler" - lora_dir = checkpoint_dir / "adapter" - - lora_only = lora_dir.exists() and is_lora_model(actor.model) - model_dir = lora_dir if lora_only else model_dir if not model_dir.exists(): - logger.info(f"[FSDP] No model checkpoint found at {model_dir} or {lora_dir}; skipping load.") + logger.info(f"[FSDP] Model checkpoint {model_dir} not found; skipping load.") return None - model_state = ModelState(actor.model, lora_only=lora_only) + # Load model weights (always) + model_state = ModelState(actor.model) state_dict = {"model_state": model_state} + try: dcp.load(state_dict=state_dict, checkpoint_id=str(model_dir)) - logger.info(f"[FSDP] Loaded {'LoRA adapter' if lora_only else 'model'} from {model_dir}") + logger.info(f"[FSDP] Loaded model from {model_dir}") except Exception as e: - logger.error(f"[FSDP] Failed to load {'LoRA adapter' if lora_only else 'model'} from {model_dir}: {e}") + logger.error(f"[FSDP] Failed to load model from {model_dir}: {e}") return None # Load optimizer state (optional) @@ -225,19 +210,9 @@ def save(actor: Any, iteration: int) -> None: dist.barrier() # Save model weights - lora_only = is_lora_model(actor.model) - if lora_only: - save_dir = checkpoint_dir / "adapter" - if dist.get_rank() == 0: - save_dir.mkdir(parents=True, exist_ok=True) - dist.barrier() - else: - save_dir = model_dir - - model_state = ModelState(actor.model, lora_only=lora_only) + model_state = ModelState(actor.model) state_dict = {"model_state": model_state} - dcp.save(state_dict, checkpoint_id=str(save_dir)) - logger.info(f"[FSDP] Saved {'LoRA adapter' if lora_only else 'model'} to {save_dir}") + dcp.save(state_dict, checkpoint_id=str(model_dir)) # Save optimizer state 
if hasattr(actor, "optimizer") and actor.optimizer is not None: diff --git a/miles/backends/fsdp_utils/lora_utils.py b/miles/backends/fsdp_utils/lora_utils.py deleted file mode 100644 index f7f85d84b..000000000 --- a/miles/backends/fsdp_utils/lora_utils.py +++ /dev/null @@ -1,76 +0,0 @@ -import logging -import os -import shutil -from pathlib import Path - -import torch.distributed as dist -import torch.nn as nn -from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict - -try: - from peft import LoraConfig, PeftModel, TaskType, get_peft_model -except ImportError as err: - raise ImportError("peft library required for LoRA. Install with: pip install peft") from err - -logger = logging.getLogger(__name__) - -LORA_ADAPTER_NAME = "miles_lora" -LORA_SUBDIR = "tmp_lora" - - -def apply_lora_to_model(model: nn.Module, args) -> nn.Module: - if args.lora_adapter_path: - logger.info(f"Loading LoRA adapter from {args.lora_adapter_path}") - model = PeftModel.from_pretrained(model, args.lora_adapter_path, is_trainable=True) - peft_config = model.peft_config["default"] - if isinstance(peft_config.task_type, str): - peft_config.task_type = TaskType.CAUSAL_LM - model.print_trainable_parameters() - return model - - lora_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, - r=args.lora_rank, - lora_alpha=args.lora_alpha, - target_modules=args.target_modules, - bias="none", - ) - - model = get_peft_model(model, lora_config) # autocast_adapter_dtype=False) - model.print_trainable_parameters() - logger.info(f"Applied LoRA: rank={args.lora_rank}, alpha={args.lora_alpha}") - return model - - -def is_lora_model(module: nn.Module) -> bool: - unwrapped = getattr(module, "_fsdp_wrapped_module", module) - return hasattr(unwrapped, "peft_config") - - -def save_lora_to_disk(module: nn.Module, save_dir: str) -> str: - """Save LoRA adapter to disk with file lock mechanism.""" - # TODO: All gather lora layers not full layers - options = 
StateDictOptions(full_state_dict=True, cpu_offload=True) - full_state_dict = get_model_state_dict(module, options=options) - - lora_state_dict = {name: param for name, param in full_state_dict.items() if "lora_" in name} - - if dist.get_rank() == 0: - save_path = Path(save_dir) - save_path.mkdir(parents=True, exist_ok=True) - - module.save_pretrained(str(save_path), state_dict=lora_state_dict) - - # TODO: check if file lock is needed or better way to do it - os.sync() - - logger.info(f"Saved LoRA adapter to {save_path}") - return save_dir - - -def delete_lora_from_disk(save_dir: str) -> None: - """Delete LoRA adapter files from disk.""" - save_path = Path(save_dir) - if save_path.exists(): - shutil.rmtree(save_path) - logger.info(f"Deleted LoRA adapter from {save_path}") diff --git a/miles/backends/fsdp_utils/update_weight_utils.py b/miles/backends/fsdp_utils/update_weight_utils.py index f48cca7f5..c8dcbd810 100644 --- a/miles/backends/fsdp_utils/update_weight_utils.py +++ b/miles/backends/fsdp_utils/update_weight_utils.py @@ -1,6 +1,5 @@ import abc import logging -import os import socket from argparse import Namespace from collections.abc import Sequence @@ -26,7 +25,6 @@ except ImportError: from sglang.srt.model_executor.model_runner import FlattenedTensorBucket # type: ignore[import] -from .lora_utils import LORA_ADAPTER_NAME, LORA_SUBDIR, delete_lora_from_disk, is_lora_model, save_lora_to_disk logger = logging.getLogger(__name__) @@ -35,9 +33,6 @@ class UpdateWeight(abc.ABC): def __init__(self, args: Namespace, model: torch.nn.Module) -> None: self.args = args self.model = model - self.weight_version = 0 - self._lora_loaded = False - self._base_synced = False @abc.abstractmethod def connect_rollout_engines( @@ -48,83 +43,38 @@ def connect_rollout_engines( pass def update_weights(self) -> None: - self.weight_version += 1 - - # Update base model if needed - if not (is_lora_model(self.model) and self._base_synced and "weight" not in 
self.args.offload_rollout_level): - bucket = [] - bucket_size = 0 - for name, param in self.model.state_dict().items(): - if any(x in name for x in ["_flat_param", "lora_"]): - continue - name = name.replace("base_model.model.", "").replace(".base_layer", "") - param_size = param.numel() * param.element_size() - if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size: - self.wait_and_update_bucket_weights(bucket) - del bucket - bucket = [] - bucket_size = 0 - - param = param.cuda() - if isinstance(param, DTensor): - # async version of param.full_tensor - param = param.redistribute( - placements=[Replicate()] * param.device_mesh.ndim, - async_op=True, - ).to_local() - bucket.append((name, param)) - bucket_size += param_size - - if bucket: + bucket = [] + bucket_size = 0 + for name, param in self.model.state_dict().items(): + param_size = param.numel() * param.element_size() + if bucket and bucket_size + param_size >= self.args.update_weight_buffer_size: self.wait_and_update_bucket_weights(bucket) del bucket - - self._base_synced = True - - # Update lora weights if needed - if is_lora_model(self.model): - self._update_lora_via_file() - - def _update_lora_via_file(self) -> None: - """Push LoRA weights to rollout engines using disk files.""" - self._lora_save_dir = os.path.join(self.args.save, LORA_SUBDIR) - if dist.get_rank() == 0: - if os.path.exists(self._lora_save_dir): - delete_lora_from_disk(self._lora_save_dir) - - dist.barrier() - - save_lora_to_disk(self.model, self._lora_save_dir) - - dist.barrier() - - if dist.get_rank() == 0: - if self._lora_loaded: - refs = [engine.unload_lora_adapter.remote(LORA_ADAPTER_NAME) for engine in self.rollout_engines] - ray.get(refs) - - refs = [engine.flush_cache.remote() for engine in self.rollout_engines] - ray.get(refs) - - refs = [ - engine.load_lora_adapter.remote(LORA_ADAPTER_NAME, self._lora_save_dir) - for engine in self.rollout_engines - ] - ray.get(refs) - - refs = [engine.flush_cache.remote() 
for engine in self.rollout_engines] - ray.get(refs) - - self._lora_loaded = True - - dist.barrier() + bucket = [] + bucket_size = 0 + + param = param.cuda() + if isinstance(param, DTensor): + # async version of param.full_tensor + param = param.redistribute( + placements=[Replicate()] * param.device_mesh.ndim, + async_op=True, + ).to_local() + bucket.append((name, param)) + bucket_size += param_size + + if bucket: + self.wait_and_update_bucket_weights(bucket) + del bucket + bucket = [] + bucket_size = 0 def wait_and_update_bucket_weights(self, bucket): bucket = [(name, param.wait()) if hasattr(param, "wait") else (name, param) for name, param in bucket] - self.update_bucket_weights(bucket, weight_version=self.weight_version) + self.update_bucket_weights(bucket) @abc.abstractmethod - def update_bucket_weights(self, named_tensors, weight_version=None) -> None: + def update_bucket_weights(self, named_tensors) -> None: pass @@ -164,7 +114,7 @@ def connect_rollout_engines( # Calculate TP rank within this SGLang engine group self.tp_rank = dist.get_rank() - start_rank - def update_bucket_weights(self, named_tensors, weight_version=None) -> None: + def update_bucket_weights(self, named_tensors) -> None: monkey_patch_torch_reductions() # Use flattened bucket approach similar to Megatron logger.info("Using flattened tensor bucket") @@ -212,7 +162,6 @@ def update_bucket_weights(self, named_tensors, weight_version=None) -> None: "serialized_named_tensors": [tensors[i] for tensors in gathered_serialized_batches], "load_format": "flattened_bucket", "flush_cache": False, - "weight_version": str(weight_version), } ref = self._ipc_engine.update_weights_from_tensor.remote(**kwargs) ray.get(ref) @@ -225,6 +174,10 @@ def update_bucket_weights(self, named_tensors, weight_version=None) -> None: class UpdateWeightFromDistributed(UpdateWeight): """Broadcast weights via a temporary NCCL group to rollout engines.""" + def __init__(self, args: Namespace, model: torch.nn.Module) -> None: + 
self.args = args + self.model = model + def connect_rollout_engines( self, rollout_engines: Sequence[ActorHandle], @@ -267,7 +220,7 @@ def connect_rollout_engines( ) ray.get(refs) - def update_bucket_weights(self, named_tensors, weight_version=None) -> None: + def update_bucket_weights(self, named_tensors) -> None: """Send names/dtypes/shapes metadata to engines, then broadcast tensors. Ensures tensors are contiguous; when `world_size == 1`, converts DTensors @@ -282,7 +235,6 @@ def update_bucket_weights(self, named_tensors, weight_version=None) -> None: dtypes=[param.dtype for _, param in named_tensors], shapes=[param.shape for _, param in named_tensors], group_name=self._group_name, - weight_version=str(weight_version), ) for engine in self.rollout_engines ] diff --git a/miles/backends/sglang_utils/sglang_engine.py b/miles/backends/sglang_utils/sglang_engine.py index c9a774b32..2e1afe625 100644 --- a/miles/backends/sglang_utils/sglang_engine.py +++ b/miles/backends/sglang_utils/sglang_engine.py @@ -278,15 +278,9 @@ def get_weight_version(self): response.raise_for_status() return response.json()["weight_version"] - def release_memory_occupation(self, tags: list[str] = None): - """ - Available tags for multi-stage resume: weights, kv_cache - """ + def release_memory_occupation(self): self.flush_cache() - return self._make_request( - "release_memory_occupation", - {"tags": tags}, - ) + return self._make_request("release_memory_occupation") def resume_memory_occupation(self, tags: list[str] = None): """ @@ -342,18 +336,6 @@ def update_weights_from_distributed( payload, ) - def load_lora_adapter(self, lora_name: str, lora_path: str): - return self._make_request( - "load_lora_adapter", - {"lora_name": lora_name, "lora_path": lora_path}, - ) - - def unload_lora_adapter(self, lora_name: str): - return self._make_request( - "unload_lora_adapter", - {"lora_name": lora_name}, - ) - def pause_generation(self): response = 
requests.post(f"http://{self.server_host}:{self.server_port}/pause_generation", json={}) response.raise_for_status() @@ -437,10 +419,6 @@ def _compute_server_args(args, rank, dist_init_addr, nccl_port, host, port, work kwargs["enable_return_routed_experts"] = True if args.fp16: kwargs["dtype"] = "float16" - if args.lora_rank > 0 or args.lora_adapter_path is not None: - kwargs["enable_lora"] = True - kwargs["max_lora_rank"] = args.lora_rank - kwargs["lora_target_modules"] = args.target_modules external_engine_need_check_fields = [k for k in kwargs.keys() if k not in _EXTERNAL_ENGINE_SKIP_CHECK_FIELDS] diff --git a/miles/ray/placement_group.py b/miles/ray/placement_group.py index 7bd842960..b6fb7a20b 100644 --- a/miles/ray/placement_group.py +++ b/miles/ray/placement_group.py @@ -177,7 +177,6 @@ def create_rollout_manager(args, pg): ray.get(rollout_manager.check_weights.remote(action="reset_tensors")) if args.offload_rollout: - # TODO: Optimization in the future: offload model weights to cpu to make more space for training? 
ray.get(rollout_manager.offload.remote()) return rollout_manager, num_rollout_per_epoch diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index cb82c8ce8..9ee0fbb8a 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -127,8 +127,8 @@ def save(self, rollout_id): def load(self, rollout_id=None): self.data_source.load(rollout_id) - def offload(self, tags: list[str] = None): - return ray.get([engine.release_memory_occupation.remote(tags=tags) for engine in self.rollout_engines]) + def offload(self): + return ray.get([engine.release_memory_occupation.remote() for engine in self.rollout_engines]) def onload(self, tags: list[str] = None): return ray.get([engine.resume_memory_occupation.remote(tags=tags) for engine in self.rollout_engines]) @@ -412,7 +412,7 @@ def init_rollout_engines(args, pg, all_rollout_engines): num_new_engines = len(rollout_engines) if num_new_engines == 0: - return num_new_engines + return num_new_engines, None if args.rollout_external: addr_and_ports = _allocate_rollout_engine_addr_and_ports_external(args=args, rollout_engines=rollout_engines) diff --git a/miles/rollout/sglang_rollout.py b/miles/rollout/sglang_rollout.py index 36f6e7ce0..2e33542a5 100644 --- a/miles/rollout/sglang_rollout.py +++ b/miles/rollout/sglang_rollout.py @@ -11,7 +11,6 @@ from packaging.version import parse from tqdm import tqdm -from miles.backends.fsdp_utils.lora_utils import LORA_ADAPTER_NAME from miles.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput from miles.rollout.filter_hub.base_types import DynamicFilterOutput from miles.utils.async_utils import run @@ -125,10 +124,6 @@ async def generate(args: Namespace, sample: Sample, sampling_params: dict[str, A "return_logprob": True, } - # Use LoRA adapter when LoRA is enabled - if args.lora_rank > 0 or args.lora_adapter_path is not None: - payload["lora_path"] = LORA_ADAPTER_NAME - if args.use_rollout_routing_replay: payload["return_routed_experts"] = True diff --git 
a/miles/utils/arguments.py b/miles/utils/arguments.py index 356b91a0e..ce6e47161 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -104,17 +104,6 @@ def add_cluster_arguments(parser): "This will always be true when --colocate is set." ), ) - parser.add_argument( - "--offload-rollout-level", - type=str, - nargs="+", - default=["kv_cache", "weight"], - help=( - "Specifies what to offload during rollout when offload-rollout is set. " - "Possible values: 'kv_cache', 'weight'. Default: both 'kv_cache' and 'weight'. " - "Example: --offload-rollout-level kv_cache weight" - ), - ) reset_arg(parser, "--distributed-backend", type=str, default="nccl") reset_arg(parser, "--distributed-timeout-minutes", type=int, default=10) @@ -1426,27 +1415,6 @@ def miles_validate_args(args): if args.save_interval is not None: assert args.save is not None, "'--save' is required when save_interval is set." - if args.lora_rank > 0: - assert args.save is not None, "'--save' is required when LoRA is enabled." - assert args.target_modules is not None, "'--target-modules' is required when LoRA is enabled." 
- - if args.target_modules == "all-linear": - modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] - elif "," in args.target_modules: - modules = [m.strip() for m in args.target_modules.split(",")] - else: - modules = [args.target_modules] - - if args.exclude_modules: - exclude_set = ( - set(m.strip() for m in args.exclude_modules.split(",")) - if "," in args.exclude_modules - else {args.exclude_modules} - ) - modules = [m for m in modules if m not in exclude_set] - - args.target_modules = modules - assert not (args.kl_coef != 0 and args.kl_loss_coef != 0), "Only one of kl_coef and kl_loss_coef can be set" if args.advantage_estimator in ["reinforce_plus_plus", "reinforce_plus_plus_baseline"]: diff --git a/requirements.txt b/requirements.txt index 3840f294d..2c20195fc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,6 @@ httpx[http2] mcp[cli] memray # needed for debugging (but is lightweight), we can put it to dev mode when using pyproject.toml omegaconf -peft pillow pylatexenc pyyaml diff --git a/tests/test_external_rollout.py b/tests/test_external_rollout.py index f12837d88..c5c0838c5 100644 --- a/tests/test_external_rollout.py +++ b/tests/test_external_rollout.py @@ -14,7 +14,7 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"hf download Qwen/Qwen2.5-0.5B-Instruct --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"huggingface-cli download Qwen/Qwen2.5-0.5B-Instruct --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") U.convert_checkpoint(model_name=MODEL_NAME, megatron_model_type=MODEL_TYPE, num_gpus_per_node=NUM_GPUS) diff --git a/tests/test_qwen2.5_0.5B_gsm8k.py b/tests/test_qwen2.5_0.5B_gsm8k.py index 26d3e3197..6302aadb6 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k.py +++ b/tests/test_qwen2.5_0.5B_gsm8k.py @@ -11,7 +11,7 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"hf download 
Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") diff --git a/tests/test_qwen2.5_0.5B_gsm8k_async.py b/tests/test_qwen2.5_0.5B_gsm8k_async.py index 878d68b1c..1c55ccb20 100644 --- a/tests/test_qwen2.5_0.5B_gsm8k_async.py +++ b/tests/test_qwen2.5_0.5B_gsm8k_async.py @@ -10,7 +10,7 @@ def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") diff --git a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py index 460f8a5a5..6967f9145 100644 --- a/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py +++ b/tests/test_qwen3_0.6B_fsdp_colocated_2xGPU.py @@ -1,29 +1,17 @@ import miles.utils.external_utils.command_utils as U MODEL_NAME = "Qwen3-0.6B" -ENABLE_LORA = U.get_bool_env_var("ENABLE_LORA", "0") def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") def execute(): ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " - lora_args = ( - ( - "--lora-rank 32 " - "--lora-alpha 32 " - "--target-modules all-linear " - f"--save /root/models/{MODEL_NAME}-lora-ckpt " - ) - if ENABLE_LORA - else "" - ) - rollout_args = ( "--prompt-data /root/datasets/gsm8k/train.parquet " "--input-key messages " @@ -62,7 +50,7 @@ def execute(): optimizer_args = ( "--optimizer adam " - f"--lr {'2e-5' if ENABLE_LORA else '1e-6'} " + "--lr 1e-6 " "--lr-decay-style constant " "--weight-decay 0.1 " "--adam-beta1 0.9 " @@ -85,17 
+73,10 @@ def execute(): "--ci-metric-checker-threshold 0.71 " # loose threshold at 60 step ) - misc_args = ( - "--actor-num-nodes 1 " - "--actor-num-gpus-per-node 2 " - "--colocate " - "--offload-rollout-level kv_cache weight " - "--train-backend fsdp " - ) + misc_args = "--actor-num-nodes 1 " "--actor-num-gpus-per-node 2 " "--colocate " "--train-backend fsdp " train_args = ( f"{ckpt_args} " - f"{lora_args} " f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " diff --git a/tests/test_qwen3_0.6B_fsdp_distributed.py b/tests/test_qwen3_0.6B_fsdp_distributed.py index c4592ffdf..b3eb416b3 100644 --- a/tests/test_qwen3_0.6B_fsdp_distributed.py +++ b/tests/test_qwen3_0.6B_fsdp_distributed.py @@ -1,30 +1,20 @@ import miles.utils.external_utils.command_utils as U MODEL_NAME = "Qwen3-0.6B" -ENABLE_LORA = U.get_bool_env_var("ENABLE_LORA", "0") + + FEW_GPU = U.get_bool_env_var("MILES_TEST_FEW_GPU", "1") def prepare(): U.exec_command("mkdir -p /root/models /root/datasets") - U.exec_command(f"hf download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") + U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}") U.hf_download_dataset("zhuzilin/gsm8k") def execute(): ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} " - lora_args = ( - ( - "--lora-rank 32 " - "--lora-alpha 32 " - "--target-modules all-linear " - f"--save /root/models/{MODEL_NAME}-lora-ckpt " - ) - if ENABLE_LORA - else "" - ) - rollout_args = ( "--prompt-data /root/datasets/gsm8k/train.parquet " "--input-key messages " @@ -64,7 +54,7 @@ def execute(): optimizer_args = ( "--optimizer adam " - f"--lr {'2e-5' if ENABLE_LORA else '1e-6'} " + "--lr 1e-6 " "--lr-decay-style constant " "--weight-decay 0.1 " "--adam-beta1 0.9 " @@ -77,7 +67,6 @@ def execute(): "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {1 if FEW_GPU else 2} " f"--rollout-num-gpus {1 if FEW_GPU else 2} " - "--offload-rollout-level kv_cache weight " "--train-backend fsdp " ) @@ -90,7 
+79,6 @@ def execute(): train_args = ( f"{ckpt_args} " - f"{lora_args} " f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " diff --git a/train.py b/train.py index 212361fc0..9fb480eda 100644 --- a/train.py +++ b/train.py @@ -56,7 +56,7 @@ def offload_train(): actor_model.clear_memory() def onload_rollout(): - if args.offload_rollout and "weight" in args.offload_rollout_level: + if args.offload_rollout: ray.get(rollout_manager.onload.remote(tags=[GPU_MEMORY_TYPE_WEIGHTS])) # train loop. @@ -68,12 +68,7 @@ def onload_rollout(): rollout_data_ref = ray.get(rollout_manager.generate.remote(rollout_id)) if args.offload_rollout: - offload_tags = [GPU_MEMORY_TYPE_CUDA_GRAPH] - if "kv_cache" in args.offload_rollout_level: - offload_tags.append(GPU_MEMORY_TYPE_KV_CACHE) - if "weight" in args.offload_rollout_level: - offload_tags.append(GPU_MEMORY_TYPE_WEIGHTS) - ray.get(rollout_manager.offload.remote(tags=offload_tags)) + ray.get(rollout_manager.offload.remote()) if args.use_critic: critic_train_handle = critic_model.async_train(rollout_id, rollout_data_ref) From 9a3b29711d2ca8e8bc789a1fdc0a44c839623c35 Mon Sep 17 00:00:00 2001 From: Ratish P <114130421+Ratish1@users.noreply.github.com> Date: Wed, 31 Dec 2025 07:40:25 +0400 Subject: [PATCH 17/21] feat: Implement lazy data loading for Dataset (#246) Co-authored-by: zhaochenyang20 Co-authored-by: PopSoda2002 --- miles/ray/rollout_data_source.py | 1 + miles/rollout/data_source.py | 1 + miles/rollout/sglang_rollout.py | 1 + miles/utils/arguments.py | 6 ++ miles/utils/data.py | 144 +++++++++++++++++++++++-------- 5 files changed, 115 insertions(+), 38 deletions(-) diff --git a/miles/ray/rollout_data_source.py b/miles/ray/rollout_data_source.py index c9df08f4f..e962a29f8 100644 --- a/miles/ray/rollout_data_source.py +++ b/miles/ray/rollout_data_source.py @@ -43,6 +43,7 @@ def __init__(self, args): apply_chat_template=args.apply_chat_template, apply_chat_template_kwargs=args.apply_chat_template_kwargs, seed=args.rollout_seed, + 
num_proc=args.num_proc, ) if self.args.rollout_shuffle: self.dataset.shuffle(self.epoch_id) diff --git a/miles/rollout/data_source.py b/miles/rollout/data_source.py index 613319d34..ca9b80f94 100644 --- a/miles/rollout/data_source.py +++ b/miles/rollout/data_source.py @@ -75,6 +75,7 @@ def __init__(self, args): apply_chat_template=args.apply_chat_template, apply_chat_template_kwargs=args.apply_chat_template_kwargs, seed=args.rollout_seed, + num_proc=args.num_proc, ) if self.args.rollout_shuffle: self.dataset.shuffle(self.epoch_id) diff --git a/miles/rollout/sglang_rollout.py b/miles/rollout/sglang_rollout.py index 2e33542a5..6ed83092b 100644 --- a/miles/rollout/sglang_rollout.py +++ b/miles/rollout/sglang_rollout.py @@ -478,6 +478,7 @@ async def eval_rollout_single_dataset( tool_key=dataset_cfg.tool_key, apply_chat_template=args.apply_chat_template, apply_chat_template_kwargs=args.apply_chat_template_kwargs, + num_proc=args.num_proc, ) dataset = EVAL_PROMPT_DATASET[cache_key] diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index ce6e47161..07b03fcbc 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -577,6 +577,12 @@ def add_data_arguments(parser): "and should be set to a larger value than `max_tokens_per_gpu` if you want better performance. 
" ), ) + parser.add_argument( + "--num-proc", + type=int, + default=8, + help="Number of processes for dataset initialization and filtering.", + ) return parser def add_eval_arguments(parser): diff --git a/miles/utils/data.py b/miles/utils/data.py index c36902c81..e26cee33c 100644 --- a/miles/utils/data.py +++ b/miles/utils/data.py @@ -1,9 +1,10 @@ import json import logging import os -import random import re +from functools import partial +import datasets import numpy as np import pandas as pd import ray @@ -16,6 +17,16 @@ logger = logging.getLogger(__name__) +_FILE_TYPE_MAP = { + ".jsonl": "json", + ".parquet": "parquet", +} + + +def _filter_func(example, tokenizer, processor, max_length, prompt_key, multimodal_keys, apply_chat_template_kwargs): + prompt = _build_messages(example, prompt_key, multimodal_keys) + return not _should_skip_prompt(prompt, tokenizer, processor, max_length, apply_chat_template_kwargs) + # TODO: don't read the whole file into memory. def read_file(path): @@ -124,53 +135,110 @@ def __init__( seed=42, apply_chat_template=False, apply_chat_template_kwargs=None, + num_proc=8, ): - self.origin_samples = [] - for data in read_file(path): - prompt = _build_messages(data, prompt_key, multimodal_keys) - - metadata = data.get(metadata_key) or {} - if tool_key is not None and tool_key in data: - tools = data[tool_key] - if isinstance(tools, str): - tools = json.loads(tools) - elif isinstance(tools, np.ndarray): - tools = tools.tolist() - assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead" - metadata["tools"] = tools - - # TODO: this is slow. - if _should_skip_prompt(prompt, tokenizer, processor, max_length, apply_chat_template_kwargs): - continue + # 1. 
Store basic config + self.tokenizer = tokenizer + self.processor = processor + self.max_length = max_length + self.prompt_key = prompt_key + self.multimodal_keys = multimodal_keys + self.label_key = label_key + self.tool_key = tool_key + self.metadata_key = metadata_key + self.apply_chat_template_kwargs = apply_chat_template_kwargs or {} + self.seed = seed + self.epoch_id = -1 - self.origin_samples.append( - Sample( - prompt=prompt, - label=data[label_key] if label_key is not None else None, - metadata=metadata, - ) - ) + # 2. Load and process dataset + self.hf_dataset = self._load_and_filter_dataset(path, num_proc) + self.origin_hf_dataset = self.hf_dataset - self.epoch_id = -1 - self.seed = seed - self.samples = self.origin_samples + def _get_file_type(self, path: str) -> str: + _, ext = os.path.splitext(path) + + try: + return _FILE_TYPE_MAP[ext] + except KeyError: + raise ValueError(f"Unsupported format: {ext}. Supported: {list(_FILE_TYPE_MAP.keys())}") from None + + def _load_and_filter_dataset(self, path, num_proc): + raw_file_path, row_slice = _parse_generalized_path(path) + + if not os.path.exists(raw_file_path): + raise FileNotFoundError(f"Prompt dataset path '{raw_file_path}' does not exist.") + + logger.info(f"Loading dataset from {raw_file_path} using Hugging Face datasets.") + + # Determine file type and load using datasets library for memory-mapped access + file_type = self._get_file_type(raw_file_path) + ds = datasets.load_dataset(file_type, data_files=raw_file_path, split="train") + + # Apply row slicing if specified + if row_slice: + num_rows = len(ds) + indices = range(num_rows)[row_slice] + ds = ds.select(indices) + logger.info(f"Applied slice {row_slice}, dataset size: {len(ds)}") + + filter_kwargs = { + "tokenizer": self.tokenizer, + "processor": self.processor, + "max_length": self.max_length, + "prompt_key": self.prompt_key, + "multimodal_keys": self.multimodal_keys, + "apply_chat_template_kwargs": self.apply_chat_template_kwargs, + } + + 
original_size = len(ds) + + ds = ds.filter(partial(_filter_func, **filter_kwargs), num_proc=num_proc, desc="Filtering invalid samples") + + new_size = len(ds) + logger.info(f"Filtered dataset from {original_size} to {new_size} samples.") + + return ds + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, idx): + # The underlying HF dataset handles lazy fetching + data = self.hf_dataset[idx] + + # Process the data using existing logic + prompt = _build_messages(data, self.prompt_key, self.multimodal_keys) + + metadata = data.get(self.metadata_key) or {} + if self.tool_key is not None and self.tool_key in data: + tools = data[self.tool_key] + if isinstance(tools, str): + tools = json.loads(tools) + # TODO (chenyang): If the JSON parsing is heavy, we might need + # to use hf_dataset.map() during init to pre-process these + # fields into a more efficient format (Arrow-native), rather + # than parsing raw strings on the fly. + elif isinstance(tools, np.ndarray): + tools = tools.tolist() + assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead" + metadata["tools"] = tools + + sample = Sample( + prompt=prompt, + label=data.get(self.label_key) if self.label_key is not None else None, + metadata=metadata, + ) + + return sample def shuffle(self, new_epoch_id): if self.epoch_id == new_epoch_id: return - random.seed(self.seed + new_epoch_id) - permutation = list(range(len(self.samples))) - random.shuffle(permutation) - self.samples = [self.origin_samples[i] for i in permutation] + logger.info(f"Shuffling dataset for epoch {new_epoch_id} with seed {self.seed + new_epoch_id}") + self.hf_dataset = self.origin_hf_dataset.shuffle(seed=self.seed + new_epoch_id) self.epoch_id = new_epoch_id - def __getitem__(self, idx): - return self.samples[idx] - - def __len__(self): - return len(self.samples) - def get_minimum_num_micro_batch_size(total_lengths, max_tokens_per_gpu): # use first fit to get the number of micro batches From 
b0d3341be1771e1ed63583ab9f62f739da472c52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B5=B5=E6=99=A8=E9=98=B3?= Date: Tue, 30 Dec 2025 20:09:35 -0800 Subject: [PATCH 18/21] Revert "feat: Implement lazy data loading for Dataset" (#372) --- miles/ray/rollout_data_source.py | 1 - miles/rollout/data_source.py | 1 - miles/rollout/sglang_rollout.py | 1 - miles/utils/arguments.py | 6 -- miles/utils/data.py | 144 ++++++++----------------------- 5 files changed, 38 insertions(+), 115 deletions(-) diff --git a/miles/ray/rollout_data_source.py b/miles/ray/rollout_data_source.py index e962a29f8..c9df08f4f 100644 --- a/miles/ray/rollout_data_source.py +++ b/miles/ray/rollout_data_source.py @@ -43,7 +43,6 @@ def __init__(self, args): apply_chat_template=args.apply_chat_template, apply_chat_template_kwargs=args.apply_chat_template_kwargs, seed=args.rollout_seed, - num_proc=args.num_proc, ) if self.args.rollout_shuffle: self.dataset.shuffle(self.epoch_id) diff --git a/miles/rollout/data_source.py b/miles/rollout/data_source.py index ca9b80f94..613319d34 100644 --- a/miles/rollout/data_source.py +++ b/miles/rollout/data_source.py @@ -75,7 +75,6 @@ def __init__(self, args): apply_chat_template=args.apply_chat_template, apply_chat_template_kwargs=args.apply_chat_template_kwargs, seed=args.rollout_seed, - num_proc=args.num_proc, ) if self.args.rollout_shuffle: self.dataset.shuffle(self.epoch_id) diff --git a/miles/rollout/sglang_rollout.py b/miles/rollout/sglang_rollout.py index 6ed83092b..2e33542a5 100644 --- a/miles/rollout/sglang_rollout.py +++ b/miles/rollout/sglang_rollout.py @@ -478,7 +478,6 @@ async def eval_rollout_single_dataset( tool_key=dataset_cfg.tool_key, apply_chat_template=args.apply_chat_template, apply_chat_template_kwargs=args.apply_chat_template_kwargs, - num_proc=args.num_proc, ) dataset = EVAL_PROMPT_DATASET[cache_key] diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 07b03fcbc..ce6e47161 100644 --- a/miles/utils/arguments.py +++ 
b/miles/utils/arguments.py @@ -577,12 +577,6 @@ def add_data_arguments(parser): "and should be set to a larger value than `max_tokens_per_gpu` if you want better performance. " ), ) - parser.add_argument( - "--num-proc", - type=int, - default=8, - help="Number of processes for dataset initialization and filtering.", - ) return parser def add_eval_arguments(parser): diff --git a/miles/utils/data.py b/miles/utils/data.py index e26cee33c..c36902c81 100644 --- a/miles/utils/data.py +++ b/miles/utils/data.py @@ -1,10 +1,9 @@ import json import logging import os +import random import re -from functools import partial -import datasets import numpy as np import pandas as pd import ray @@ -17,16 +16,6 @@ logger = logging.getLogger(__name__) -_FILE_TYPE_MAP = { - ".jsonl": "json", - ".parquet": "parquet", -} - - -def _filter_func(example, tokenizer, processor, max_length, prompt_key, multimodal_keys, apply_chat_template_kwargs): - prompt = _build_messages(example, prompt_key, multimodal_keys) - return not _should_skip_prompt(prompt, tokenizer, processor, max_length, apply_chat_template_kwargs) - # TODO: don't read the whole file into memory. def read_file(path): @@ -135,110 +124,53 @@ def __init__( seed=42, apply_chat_template=False, apply_chat_template_kwargs=None, - num_proc=8, ): - # 1. Store basic config - self.tokenizer = tokenizer - self.processor = processor - self.max_length = max_length - self.prompt_key = prompt_key - self.multimodal_keys = multimodal_keys - self.label_key = label_key - self.tool_key = tool_key - self.metadata_key = metadata_key - self.apply_chat_template_kwargs = apply_chat_template_kwargs or {} - self.seed = seed - self.epoch_id = -1 - - # 2. 
Load and process dataset - self.hf_dataset = self._load_and_filter_dataset(path, num_proc) - self.origin_hf_dataset = self.hf_dataset - - def _get_file_type(self, path: str) -> str: - _, ext = os.path.splitext(path) - - try: - return _FILE_TYPE_MAP[ext] - except KeyError: - raise ValueError(f"Unsupported format: {ext}. Supported: {list(_FILE_TYPE_MAP.keys())}") from None - - def _load_and_filter_dataset(self, path, num_proc): - raw_file_path, row_slice = _parse_generalized_path(path) - - if not os.path.exists(raw_file_path): - raise FileNotFoundError(f"Prompt dataset path '{raw_file_path}' does not exist.") - - logger.info(f"Loading dataset from {raw_file_path} using Hugging Face datasets.") - - # Determine file type and load using datasets library for memory-mapped access - file_type = self._get_file_type(raw_file_path) - ds = datasets.load_dataset(file_type, data_files=raw_file_path, split="train") - - # Apply row slicing if specified - if row_slice: - num_rows = len(ds) - indices = range(num_rows)[row_slice] - ds = ds.select(indices) - logger.info(f"Applied slice {row_slice}, dataset size: {len(ds)}") - - filter_kwargs = { - "tokenizer": self.tokenizer, - "processor": self.processor, - "max_length": self.max_length, - "prompt_key": self.prompt_key, - "multimodal_keys": self.multimodal_keys, - "apply_chat_template_kwargs": self.apply_chat_template_kwargs, - } - - original_size = len(ds) - - ds = ds.filter(partial(_filter_func, **filter_kwargs), num_proc=num_proc, desc="Filtering invalid samples") - - new_size = len(ds) - logger.info(f"Filtered dataset from {original_size} to {new_size} samples.") - - return ds + self.origin_samples = [] + for data in read_file(path): + prompt = _build_messages(data, prompt_key, multimodal_keys) + + metadata = data.get(metadata_key) or {} + if tool_key is not None and tool_key in data: + tools = data[tool_key] + if isinstance(tools, str): + tools = json.loads(tools) + elif isinstance(tools, np.ndarray): + tools = tools.tolist() + 
assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead" + metadata["tools"] = tools + + # TODO: this is slow. + if _should_skip_prompt(prompt, tokenizer, processor, max_length, apply_chat_template_kwargs): + continue - def __len__(self): - return len(self.hf_dataset) + self.origin_samples.append( + Sample( + prompt=prompt, + label=data[label_key] if label_key is not None else None, + metadata=metadata, + ) + ) - def __getitem__(self, idx): - # The underlying HF dataset handles lazy fetching - data = self.hf_dataset[idx] - - # Process the data using existing logic - prompt = _build_messages(data, self.prompt_key, self.multimodal_keys) - - metadata = data.get(self.metadata_key) or {} - if self.tool_key is not None and self.tool_key in data: - tools = data[self.tool_key] - if isinstance(tools, str): - tools = json.loads(tools) - # TODO (chenyang): If the JSON parsing is heavy, we might need - # to use hf_dataset.map() during init to pre-process these - # fields into a more efficient format (Arrow-native), rather - # than parsing raw strings on the fly. 
- elif isinstance(tools, np.ndarray): - tools = tools.tolist() - assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead" - metadata["tools"] = tools - - sample = Sample( - prompt=prompt, - label=data.get(self.label_key) if self.label_key is not None else None, - metadata=metadata, - ) - - return sample + self.epoch_id = -1 + self.seed = seed + self.samples = self.origin_samples def shuffle(self, new_epoch_id): if self.epoch_id == new_epoch_id: return - logger.info(f"Shuffling dataset for epoch {new_epoch_id} with seed {self.seed + new_epoch_id}") - self.hf_dataset = self.origin_hf_dataset.shuffle(seed=self.seed + new_epoch_id) + random.seed(self.seed + new_epoch_id) + permutation = list(range(len(self.samples))) + random.shuffle(permutation) + self.samples = [self.origin_samples[i] for i in permutation] self.epoch_id = new_epoch_id + def __getitem__(self, idx): + return self.samples[idx] + + def __len__(self): + return len(self.samples) + def get_minimum_num_micro_batch_size(total_lengths, max_tokens_per_gpu): # use first fit to get the number of micro batches From 8ba715e5714dd83d6335b6136d8d59e1f61ba396 Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Tue, 30 Dec 2025 20:48:59 -0800 Subject: [PATCH 19/21] [MISC] add codeowners (#373) --- .github/CODEOWNERS | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 .github/CODEOWNERS diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 000000000..5d094d3de --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,2 @@ +.github/ @fzyzcjy @yushengsu-thu @Ying1123 +/miles/ @fzyzcjy @yueming-yuan From bc61a7d55fcd2739dbba565949022ae808fa7ed0 Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Tue, 30 Dec 2025 21:01:23 -0800 Subject: [PATCH 20/21] [Misc] update codeowners (#374) --- .github/CODEOWNERS | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 5d094d3de..dc0cc7cbc 100644 --- a/.github/CODEOWNERS +++ 
b/.github/CODEOWNERS @@ -1,2 +1,3 @@ -.github/ @fzyzcjy @yushengsu-thu @Ying1123 +.github/CODEOWNERS @fzyzcjy @Ying1123 +.github/workflows/ @yushengsu-thu /miles/ @fzyzcjy @yueming-yuan From 66c14c48e05ca8eaaa94994e9b0068ba31f69f9d Mon Sep 17 00:00:00 2001 From: mihir <78321484+maharajamihir@users.noreply.github.com> Date: Tue, 23 Dec 2025 16:18:27 +0100 Subject: [PATCH 21/21] Added uv build scripts (#2) * added build uv scripts * fix installation scripts * simpler venv naming * modify uv install scripts * remove hardcoded path for basedir * address franz' points --- build_uv_berlin.sh | 194 +++++++++++++++++++++++++++++++++++++++++++ build_uv_juelich.sh | 198 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 392 insertions(+) create mode 100644 build_uv_berlin.sh create mode 100644 build_uv_juelich.sh diff --git a/build_uv_berlin.sh b/build_uv_berlin.sh new file mode 100644 index 000000000..c4f6cd7cd --- /dev/null +++ b/build_uv_berlin.sh @@ -0,0 +1,194 @@ +#!/bin/bash + +# ============================================================================= +# Miles Build Script (CUDA 12.8 Slurm Version) +# +# This script uses uv for Python environment management. +# It relies on the Slurm module system for CUDA 12.8 instead of pip packages. +# +# Configuration: +# - CUDA: 12.8 (via module load) +# - PyTorch: 2.8.0 (cu128) +# - Flash Attention 3: Prebuilt for cu128 + torch2.8 +# - Flash Attention 2: Prebuilt for cu128 + torch2.8 +# ============================================================================= + +set -e # Exit on error + +BASE_DIR="$(pwd)/.." # change this if you want a different base directory. Default is parent directory of miles + +if [ -z "$BASE_DIR" ]; then + echo "BASE_DIR is not set. Please set it to proceed with the installation." 
+ exit 1 +fi + +# ============================================================================= +# Load Slurm Module +# ============================================================================= +echo "Loading CUDA 12.8 module..." +module load CUDA/12.8 + +# Verify NVCC is in path +if ! command -v nvcc &> /dev/null; then + echo "CRITICAL ERROR: nvcc not found after loading module." + echo "Please ensure 'module load CUDA/12.8' works on this cluster." + exit 1 +fi + +# ============================================================================= +# Install uv if not already installed +# ============================================================================= +if ! command -v uv &> /dev/null; then + echo "Installing uv..." + curl -LsSf https://astral.sh/uv/install.sh | sh + export PATH="$HOME/.local/bin:$PATH" +fi + +# ============================================================================= +# Create Python virtual environment with uv +# ============================================================================= +# Create virtual environment with Python 3.12 +uv venv --python 3.12 + +# Activate the virtual environment +source ".venv/bin/activate" + +cd "$BASE_DIR" +# ============================================================================= +# Install PyTorch with CUDA 12.8 +# ============================================================================= +echo "Installing PyTorch 2.8.0 with CUDA 12.8..." 
+ +# Install cuda-python (Pinned to 12.8 to match the module) +uv pip install cuda-python==12.8.0 + +# Install PyTorch 2.8.0 for CUDA 12.8 +uv pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu128 + +# Set TORCH_CUDA_ARCH_LIST for our GPU architectures +# 8.0 = A100 (Ampere), 9.0 = H100 (Hopper) +export TORCH_CUDA_ARCH_LIST="8.0;9.0" + +# ============================================================================= +# Install sglang +# ============================================================================= +echo "Installing sglang..." +cd "$BASE_DIR" +if [ ! -d "$BASE_DIR/sglang" ]; then + git clone https://github.com/sgl-project/sglang.git +fi +cd sglang +git checkout 303cc957e62384044dfa8e52d7d8af8abe12f0ac +uv pip install -e "python[all]" + +# ============================================================================= +# Install build tools +# ============================================================================= +uv pip install cmake ninja packaging build wheel + +# ============================================================================= +# Install Flash Attention 3 (prebuilt wheels for cu128 + torch2.8) +# ============================================================================= +echo "Installing Flash Attention 3..." +# Using windreamer's wheel index for CUDA 12.8 and PyTorch 2.8.0 +uv pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch280 --extra-index-url https://download.pytorch.org/whl/cu128 + +# ============================================================================= +# Install Flash Attention 2 (prebuilt wheel for Megatron compatibility) +# ============================================================================= +echo "Installing Flash Attention 2..." 
+uv pip install https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.3.18/flash_attn-2.7.4%2Bcu128torch2.8-cp312-cp312-linux_x86_64.whl + +# ============================================================================= +# Install mbridge, transformer_engine, flash-linear-attention +# ============================================================================= +echo "Installing mbridge, transformer_engine, flash-linear-attention..." +uv pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps + +# Transformer Engine 2.8.0 (compatible with CUDA 12.8) +uv pip install --no-build-isolation "transformer_engine[pytorch]==2.8.0" --no-cache-dir + +uv pip install flash-linear-attention==0.4.0 + +# ============================================================================= +# Install NVIDIA Apex (requires CUDA compilation) +# ============================================================================= +echo "Installing NVIDIA Apex (Compiling from source)..." +# We utilize the loaded CUDA module for compilation +NVCC_APPEND_FLAGS="--threads 4" \ + APEX_CPP_EXT=1 \ + APEX_CUDA_EXT=1 \ + APEX_PARALLEL_BUILD=8 \ + uv pip install -v --no-cache-dir \ + --no-build-isolation \ + git+https://github.com/NVIDIA/apex.git@10417aceddd7d5d05d7cbf7b0fc2daad1105f8b4 + +# ============================================================================= +# Install Megatron-LM +# ============================================================================= +echo "Installing Megatron-LM..." +cd "$BASE_DIR" +if [ ! -d "$BASE_DIR/Megatron-LM" ]; then + git clone https://github.com/NVIDIA/Megatron-LM.git --recursive +fi +cd Megatron-LM +git checkout core_v0.14.0 +uv pip install -e . + +# ============================================================================= +# Install additional dependencies +# ============================================================================= +echo "Installing additional dependencies..." 
+uv pip install poetry pybind11 +uv pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@9b8b788fdeb9c2ee528183214cef65a99b71e7d5 --no-cache-dir --force-reinstall +uv pip install git+https://github.com/fzyzcjy/Megatron-Bridge.git@dev_rl --no-build-isolation +uv pip install "nvidia-modelopt[torch]>=0.37.0" --no-build-isolation + +# ============================================================================= +# Install remaining packages +# ============================================================================= +uv pip install sglang_router ring_flash_attn pylatexenc +uv pip install -U "ray[data,train,tune,serve]" + +# ============================================================================= +# Install miles +# ============================================================================= +echo "Installing miles..." +if [ ! -d "$BASE_DIR/miles" ]; then + cd "$BASE_DIR" + git clone https://github.com/radixark/miles.git + cd miles/ + export MILES_DIR="$BASE_DIR/miles" + uv pip install -e . +elif [ -f "$BASE_DIR/pyproject.toml" ]; then + export MILES_DIR="$BASE_DIR" + cd "$MILES_DIR" + uv pip install -e . +else + export MILES_DIR="$BASE_DIR/miles" + cd "$MILES_DIR" + uv pip install -e . +fi + +# ============================================================================= +# Apply patches +# ============================================================================= +echo "Applying patches..." +cd "$BASE_DIR/sglang" +git apply "$MILES_DIR/docker/patch/v0.5.5.post1/sglang.patch" || echo "sglang patch already applied or failed" + +cd "$BASE_DIR/Megatron-LM" +git apply "$MILES_DIR/docker/patch/v0.5.5.post1/megatron.patch" || echo "Megatron patch already applied or failed" + +echo "" +echo "=============================================================================" +echo "Installation complete!" 
+echo "" +echo "To activate the environment, run:" +echo " module load CUDA/12.8" +echo " source $BASE_DIR/miles-venv/bin/activate" +echo "" +echo "Environment configured using System CUDA from 'module load CUDA/12.8'" +echo "PyTorch Version: 2.8.0 (cu128)" +echo "CUDA_HOME: $CUDA_HOME" +echo "=============================================================================" \ No newline at end of file diff --git a/build_uv_juelich.sh b/build_uv_juelich.sh new file mode 100644 index 000000000..9ba6830d2 --- /dev/null +++ b/build_uv_juelich.sh @@ -0,0 +1,198 @@ +#!/bin/bash + +# ============================================================================= +# Miles Build Script (CUDA 12.6 Slurm Version) +# +# This script uses uv for Python environment management. +# It relies on the Slurm module system for CUDA 12.6 instead of pip packages. +# +# Configuration: +# - CUDA: 12.6 (via module load) +# - PyTorch: 2.8.0 (cu126) +# - Flash Attention 3: Prebuilt for cu126 + torch2.8 +# - Flash Attention 2: Prebuilt for cu128 + torch2.8 +# ============================================================================= + +set -e # Exit on error + +BASE_DIR="$(pwd)/.." # change this if you want a different base directory. Default is parent directory of miles + +if [ -z "$BASE_DIR" ]; then + echo "BASE_DIR is not set. Please set it to proceed with the installation." + exit 1 +fi + +# ============================================================================= +# Load Slurm Module +# ============================================================================= +echo "Loading CUDA 12.6 module..." +module load CUDA/12 +module load cuDNN/9.5.0.50-CUDA-12 +module load NCCL/default-CUDA-12 +module load Clang/18.1.8 + +# Verify NVCC is in path +if ! command -v nvcc &> /dev/null; then + echo "CRITICAL ERROR: nvcc not found after loading module." + echo "Please ensure 'module load CUDA/12' works on this cluster." 
+ exit 1 +fi + +# ============================================================================= +# Install uv if not already installed +# ============================================================================= +if ! command -v uv &> /dev/null; then + echo "Installing uv..." + curl -LsSf https://astral.sh/uv/install.sh | sh + export PATH="$HOME/.local/bin:$PATH" +fi + +# ============================================================================= +# Create Python virtual environment with uv +# ============================================================================= +# Create virtual environment with Python 3.12 +uv venv --python 3.12 + +# Activate the virtual environment +source ".venv/bin/activate" + +cd "$BASE_DIR" +# ============================================================================= +# Install PyTorch with CUDA 12.6 +# ============================================================================= +echo "Installing PyTorch 2.8.0 with CUDA 12.6..." + +# Install cuda-python (Pinned to 12.6 to match the module) +uv pip install cuda-python==12.6.0 + +# Install PyTorch 2.8.0 for CUDA 12.6 +uv pip install torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu126 + +# Set TORCH_CUDA_ARCH_LIST for our GPU architectures +# 8.0 = A100 (Ampere), 9.0 = H100 (Hopper) +export TORCH_CUDA_ARCH_LIST="8.0;9.0" + +# ============================================================================= +# Install sglang +# ============================================================================= +echo "Installing sglang..." +cd "$BASE_DIR" +if [ ! 
-d "$BASE_DIR/sglang" ]; then + git clone https://github.com/sgl-project/sglang.git +fi +cd sglang +git checkout 303cc957e62384044dfa8e52d7d8af8abe12f0ac +uv pip install -e "python[all]" + +# ============================================================================= +# Install build tools +# ============================================================================= +uv pip install cmake ninja packaging build wheel + +# ============================================================================= +# Install Flash Attention 3 (prebuilt wheels for cu126 + torch2.8) +# ============================================================================= +echo "Installing Flash Attention 3..." +# Using windreamer's wheel index for CUDA 12.6 and PyTorch 2.8.0 +uv pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu126_torch280 --extra-index-url https://download.pytorch.org/whl/cu126 +# ============================================================================= +# Install Flash Attention 2 (prebuilt wheel for Megatron compatibility) +# ============================================================================= +echo "Installing Flash Attention 2..." +uv pip install https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.11/flash_attn-2.8.3+cu126torch2.8-cp312-cp312-linux_x86_64.whl + +# ============================================================================= +# Install mbridge, transformer_engine, flash-linear-attention +# ============================================================================= +echo "Installing mbridge, transformer_engine, flash-linear-attention..." 
+uv pip install git+https://github.com/ISEEKYAN/mbridge.git@89eb10887887bc74853f89a4de258c0702932a1c --no-deps
+
+# Transformer Engine 2.8.0 (compatible with CUDA 12.6)
+uv pip install --no-build-isolation "transformer_engine[pytorch]==2.8.0" --no-cache-dir
+
+uv pip install flash-linear-attention==0.4.0
+
+# =============================================================================
+# Install NVIDIA Apex (requires CUDA compilation)
+# =============================================================================
+echo "Installing NVIDIA Apex (Compiling from source)..."
+# We utilize the loaded CUDA module for compilation
+NVCC_APPEND_FLAGS="--threads 4" \
+    APEX_CPP_EXT=1 \
+    APEX_CUDA_EXT=1 \
+    APEX_PARALLEL_BUILD=8 \
+    uv pip install -v --no-cache-dir \
+    --no-build-isolation \
+    git+https://github.com/NVIDIA/apex.git@10417aceddd7d5d05d7cbf7b0fc2daad1105f8b4
+
+# =============================================================================
+# Install Megatron-LM
+# =============================================================================
+echo "Installing Megatron-LM..."
+cd "$BASE_DIR"
+if [ ! -d "$BASE_DIR/Megatron-LM" ]; then
+    git clone https://github.com/NVIDIA/Megatron-LM.git --recursive
+fi
+cd Megatron-LM
+git checkout core_v0.14.0
+uv pip install -e .
+
+# =============================================================================
+# Install additional dependencies
+# =============================================================================
+echo "Installing additional dependencies..."
+uv pip install poetry pybind11 +uv pip install git+https://github.com/fzyzcjy/torch_memory_saver.git@9b8b788fdeb9c2ee528183214cef65a99b71e7d5 --no-cache-dir --force-reinstall +uv pip install git+https://github.com/fzyzcjy/Megatron-Bridge.git@dev_rl --no-build-isolation +uv pip install "nvidia-modelopt[torch]>=0.37.0" --no-build-isolation + +# ============================================================================= +# Install remaining packages +# ============================================================================= +uv pip install sglang_router ring_flash_attn pylatexenc +uv pip install -U "ray[data,train,tune,serve]" + +# ============================================================================= +# Install miles +# ============================================================================= +echo "Installing miles..." +if [ ! -d "$BASE_DIR/miles" ]; then + cd "$BASE_DIR" + git clone https://github.com/radixark/miles.git + cd miles/ + export MILES_DIR="$BASE_DIR/miles" + uv pip install -e . +elif [ -f "$BASE_DIR/pyproject.toml" ]; then + export MILES_DIR="$BASE_DIR" + cd "$MILES_DIR" + uv pip install -e . +else + export MILES_DIR="$BASE_DIR/miles" + cd "$MILES_DIR" + uv pip install -e . +fi + +# ============================================================================= +# Apply patches +# ============================================================================= +echo "Applying patches..." +cd "$BASE_DIR/sglang" +git apply "$MILES_DIR/docker/patch/v0.5.5.post1/sglang.patch" || echo "sglang patch already applied or failed" + +cd "$BASE_DIR/Megatron-LM" +git apply "$MILES_DIR/docker/patch/v0.5.5.post1/megatron.patch" || echo "Megatron patch already applied or failed" + +echo "" +echo "=============================================================================" +echo "Installation complete!" 
+echo "" +echo "To activate the environment, run:" +echo " module load CUDA/12" +echo " module load cuDNN/9.5.0.50-CUDA-12" +echo " module load NCCL/default-CUDA-12" +echo " source $BASE_DIR/venv/bin/activate" +echo "" +echo "Environment configured using System CUDA from 'module load CUDA/12'" +echo "PyTorch Version: 2.8.0 (cu126)" +echo "CUDA_HOME: $CUDA_HOME" +echo "=============================================================================" \ No newline at end of file