From 22b27853f334ac6790bf64374f643757cd73933c Mon Sep 17 00:00:00 2001
From: millioniron
Date: Wed, 3 Dec 2025 14:08:18 +0800
Subject: [PATCH 01/58] new file: docs_roll/docs/User Guides/Configuration/infer_correction.md
 new file: examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml
 new file: examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml
 new file: examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh
 new file: examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh
 modified: roll/configs/base_config.py
 modified: roll/configs/generating_args.py
 modified: roll/distributed/scheduler/generate_scheduler.py
 modified: roll/pipeline/agentic/env_manager/step_env_manager.py
 modified: roll/pipeline/agentic/env_manager/traj_env_manager.py
 modified: roll/pipeline/base_worker.py
 modified: roll/pipeline/rlvr/actor_pg_worker.py
 modified: roll/pipeline/rlvr/actor_worker.py
---
 .../Configuration/infer_correction.md         | 142 ++++++++++
 .../agentic_webshop_infer_correction.yaml     | 183 ++++++++++++
 .../rlvr_infer_correction_config.yaml         | 265 ++++++++++++++++++
 .../run_agentic_pipeline_webshop.sh           |  10 +
 .../run_rlvr_pipeline.sh                      |   5 +
 roll/configs/base_config.py                   |  58 +++-
 roll/configs/generating_args.py               |   4 +
 .../scheduler/generate_scheduler.py           |   1 +
 .../agentic/env_manager/step_env_manager.py   |   4 +-
 .../agentic/env_manager/traj_env_manager.py   |  11 +-
 roll/pipeline/base_worker.py                  | 165 ++++++++++-
 roll/pipeline/rlvr/actor_pg_worker.py         |  10 +-
 roll/pipeline/rlvr/actor_worker.py            |  23 +-
 13 files changed, 857 insertions(+), 24 deletions(-)
 create mode 100644 docs_roll/docs/User Guides/Configuration/infer_correction.md
 create mode 100644 examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml
 create mode 100644 examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml
 create mode 100644 examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh
 create mode 100644 examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh

diff --git a/docs_roll/docs/User Guides/Configuration/infer_correction.md b/docs_roll/docs/User Guides/Configuration/infer_correction.md
new file mode 100644
index 000000000..ecdce6325
--- /dev/null
+++ b/docs_roll/docs/User Guides/Configuration/infer_correction.md
@@ -0,0 +1,142 @@
+# Training-Inference Mismatch Correction
+
+## Overview
+
+During RL training, the trainer and the generator run on different backends (vLLM vs SGLang vs FSDP vs Megatron) and at different precisions (FP8 vs FP16 vs BF16 vs FP32). This creates an off-policy-like gap between training and rollout, which can destabilize training and lead to policy collapse.
+
+## How It Works
+
+Repairing the training-inference mismatch comes down to two approaches: (1) apply a policy correction between the trainer and the generator, or (2) use `infer_log_probs` directly in place of the trainer's `old_log_probs` when computing the PPO ratio. The second option is straightforward, so this guide focuses on the first.
+
+### Policy correction between trainer and generator
+
+#### IS weight correction
+Importance sampling (IS) between the trainer (`old_log_probs`) and the generator (`infer_log_probs`) is applied to bridge the training-inference gap. As in off-policy algorithms, the IS weight can be computed at the token level or at the sequence level; only one of the two can be active at a time.
+
+#### Masking out samples that violate the thresholds
+Unlike IS weight correction, this approach directly masks samples whose ratio exceeds the configured thresholds and removes them from the loss. The available variants are: (1) token level: mask individual tokens that fall outside the thresholds; (2) catastrophic tokens: drop every sequence that contains a token with a catastrophically large deviation; (3) sequence level: compute an IS ratio per sequence and drop sequences that fall outside the thresholds; (4) sequence level with a geometric mean: compute the sequence IS weight as the geometric mean of the token ratios, which makes the metric more sensitive.
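+
+The minimal sketch below illustrates both correction families on dummy tensors. It is illustrative only: tensor shapes and threshold values are made-up examples, not ROLL defaults, and it is independent of the actual worker code.
+
+```python
+import torch
+
+torch.manual_seed(0)
+
+# Illustrative shapes: a batch of 4 responses with 8 token positions each.
+old_log_probs = torch.randn(4, 8) * 0.1 - 1.0                 # trainer log-probs
+infer_log_probs = old_log_probs + torch.randn(4, 8) * 0.05    # generator log-probs
+response_mask = torch.ones(4, 8)                              # 1 = valid response token
+
+log_ratio = old_log_probs - infer_log_probs                   # [B, T]
+token_ratio = log_ratio.exp()                                 # per-token IS ratio
+
+# (1) IS weight correction: pick ONE granularity, then clip to [min, max].
+token_is = token_ratio.clamp(0.0, 2.0)                        # infer_is_mode: token
+seq_log_ratio = (log_ratio * response_mask).sum(-1, keepdim=True)
+seq_is = seq_log_ratio.exp().expand_as(token_ratio).clamp(0.0, 2.0)  # infer_is_mode: sequence
+
+# (2) Masking: drop tokens / whole sequences whose ratio leaves the allowed band.
+keep = response_mask.bool()
+keep &= (token_ratio >= 0.0) & (token_ratio <= 2.0)           # token-level reject
+catastrophic = (token_ratio < 1e-4) & response_mask.bool()    # catastrophic reject
+keep &= ~catastrophic.any(-1, keepdim=True)
+seq_ratio = (log_ratio * response_mask).sum(-1).exp()         # sequence-level reject
+keep &= ((seq_ratio >= 0.1) & (seq_ratio <= 10.0)).unsqueeze(-1)
+
+# The corrected loss is then pg_loss * is_weight, aggregated only over `keep`.
+print(token_is.shape, seq_is.shape, keep.float().mean().item())
+```
+
+In the pipeline these choices are driven by the `infer_correction` options documented below.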
+
+## Key Parameter Configuration
+
+Whether the generator returns infer_log_probs:
+
+GeneratingArguments:
+
+```yaml
+actor_infer:
+  model_args:
+    disable_gradient_checkpointing: true
+    dtype: fp16
+  generating_args:
+    max_new_tokens: ${response_length}
+    top_p: 0.99
+    top_k: 100
+    num_beams: 1
+    temperature: 0.99
+    num_return_sequences: ${num_return_sequences_in_group}
+    logprobs: 1
+```
+When `logprobs` is greater than 0, `infer_log_probs` are returned.
+
+-------
+
+These options are configured in `PPOConfig`; the key parameters are:
+
+```yaml
+infer_correction: true
+
+infer_is_mode: token # options: token, sequence
+infer_is_threshold_min: 0.0
+infer_is_threshold_max: 2.0 # 1.5~5.0
+
+enable_token_reject: true
+infer_token_mask_threshold_min: 0.0
+infer_token_mask_threshold_max: 2.0 # 2~10
+
+enable_catastrophic_reject: true
+infer_catastrophic_threshold: 1e-4
+
+enable_seq_reject: sequence # options: None, sequence, geometric
+infer_seq_mask_threshold_min: 0.1
+infer_seq_mask_threshold_max: 10
+```
+
+### infer_correction
+- **Meaning**: Whether the training-inference mismatch correction is enabled. When enabled, `infer_log_probs` is used to correct the policy gradient.
+- **Default**: `false`
+
+### infer_is_mode
+- **Meaning**: Granularity at which importance-sampling (IS) weights are computed.
+- **Options**:
+  - `"token"`: an independent IS weight per token
+  - `"sequence"`: the log-ratios over the whole response are summed and the resulting weight is broadcast to every token
+  - `"none"` (or unset): no IS weighting; the weight is fixed at 1
+- **Default**: `"token"` if not set
+- **Note**: the modes are mutually exclusive; only one can be used at a time.
+
+### infer_is_threshold_min
+- **Meaning**: Lower clipping bound for the IS weight, used to clip very small weights and control variance.
+- **Default**: `0.0`
+- **Recommendation**: usually keep `0.0` to preserve the unbiased lower bound.
+
+### infer_is_threshold_max
+- **Meaning**: Upper clipping bound for the IS weight, preventing extremely large weights from dominating the gradient.
+- **Default**: `2.0`
+- **Recommendation**: `1.5 ~ 5.0` for `"token"` mode, `2.0 ~ 10.0` for `"sequence"` mode.
+
+### enable_token_reject
+- **Meaning**: Whether token-level rejection is enabled.
+- **Default**: `false`
+- **Effect**: together with `infer_token_mask_threshold_min/max`, masks out tokens whose IS ratio falls outside the allowed range.
+
+### infer_token_mask_threshold_min
+- **Meaning**: Lower bound on the token-level IS ratio (the exponential of `old_log_probs - infer_log_probs`).
+- **Default**: `0.0`
+- **Typical value**: `0.0`; it can also be set to `1/max`.
+
+### infer_token_mask_threshold_max
+- **Meaning**: Upper bound on the token-level IS ratio above which tokens are masked.
+- **Default**: `2.0`
+- **Typical range**: `1.5 ~ 5.0`
+
+### enable_catastrophic_reject
+- **Meaning**: Whether "catastrophic deviation" detection is enabled; a hit rejects the whole sequence.
+- **Default**: `false`
+- **Trigger**: if any token in the sequence satisfies `ratio < infer_catastrophic_threshold`, the entire sequence is masked.
+
+### infer_catastrophic_threshold
+- **Meaning**: Lower ratio threshold for catastrophic rejection.
+- **Default**: `1e-4`
+- **Explanation**: when `infer_log_probs` is much larger than `old_log_probs` (the generator is overly "confident"), `ratio = exp(old - infer)` becomes extremely small and the sequence is rejected.
+
+### enable_seq_reject
+- **Meaning**: Whether sequence-level rejection is enabled, and which aggregation it uses.
+- **Options**:
+  - `null` / `false`: disabled
+  - `"sequence"`: the sequence IS ratio is the exponential of the **sum** of log-ratios
+  - `"geometric"`: the sequence IS ratio is the exponential of the **mean** of log-ratios (equivalent to the geometric mean of the token ratios)
+- **Default**: `null`
+
+### infer_seq_mask_threshold_min
+- **Meaning**: Lower bound on the sequence-level IS ratio below which the sequence is masked.
+- **Default**: `0.1`
+- **Typical value**: usually `1/max`; with `"geometric"` it is best to force it to `1/max`.
+
+### infer_seq_mask_threshold_max
+- **Meaning**: Upper bound on the sequence-level IS ratio above which the sequence is masked.
+- **Default**: `10.0`
+- **Typical range**: with `"sequence"`, `2.0 ~ 10.0` is recommended and can be relaxed as responses get longer; with `"geometric"`, `1.0001 ~ 1.001` is recommended.
+
+## Usage Tips
+
+1. In most cases `old_log_prob << infer_log_prob`, so the lower thresholds matter most. Sequence-level IS weighting or masking is generally not recommended.
+
diff --git a/examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml b/examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml
new file mode 100644
index 000000000..78fefee47
--- /dev/null
+++ b/examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml
@@ -0,0 +1,183 @@
+defaults:
+  - ../config/traj_envs@_here_
+  - ../config/deepspeed_zero@_here_
+  - ../config/deepspeed_zero2@_here_
+  - ../config/deepspeed_zero3@_here_
+  - ../config/deepspeed_zero3_cpuoffload@_here_
+
+hydra:
+  run:
+    dir: .
+ output_subdir: null + +exp_name: "agentic_pipeline_webshop_infer_correction" +seed: 42 +logging_dir: ./output/logs +output_dir: ./output +render_save_dir: ./output/render +system_envs: + USE_MODELSCOPE: '1' + +#track_with: wandb +#tracker_kwargs: +# api_key: +# project: roll-agentic +# name: ${exp_name}_webshop +# notes: "agentic_pipeline" +# tags: +# - agentic +# - roll +# - baseline + +#track_with: swanlab +#tracker_kwargs: +# login_kwargs: +# api_key: your_api_key +# project: roll-agentic +# logdir: debug +# experiment_name: ${exp_name} +# tags: +# - roll +# - agentic +# - debug + +track_with: tensorboard +tracker_kwargs: + log_dir: /data/oss_bucket_0/yali/llm/tensorboard/roll_exp/agentic_webshop + +num_gpus_per_node: 8 + +max_steps: 1024 +save_steps: 10000 +logging_steps: 1 +eval_steps: 10 +resume_from_checkpoint: false + +rollout_batch_size: 64 +val_batch_size: 64 +sequence_length: 8192 + +reward_clip: 20 +advantage_clip: 0.2 # 0.1-0.3 +ppo_epochs: 1 +adv_estimator: "grpo" +#pg_clip: 0.1 +max_grad_norm: 1.0 +#dual_clip_loss: True +init_kl_coef: 0.0 +whiten_advantages: true +entropy_loss_coef: 0 + +pretrain: Qwen/Qwen2.5-7B-Instruct +reward_pretrain: Qwen/Qwen2.5-7B-Instruct + +# infer correction + +infer_correction: true + +infer_is_mode: token #token sequence +infer_is_threshold_min: 0.0 +infer_is_threshold_max: 2.0 # 1.5~5.0 + +enable_token_reject: false +infer_token_mask_threshold_min: 0.0 +infer_token_mask_threshold_max: 2.0 # 2~10 + +enable_catastrophic_reject: false +infer_catastrophic_threshold: 1e-4 + + +enable_seq_reject: None # None sequence geometric +infer_seq_mask_threshold_min: 0.999 +infer_seq_mask_threshold_max: 1.001 + +# enable_seq_reject: geometric +# infer_seq_mask_threshold_min: 0.999 +# infer_seq_mask_threshold_max: 1.001 + + +actor_train: + model_args: + attn_implementation: fa2 + disable_gradient_checkpointing: false + dtype: bf16 + model_type: ~ + training_args: + learning_rate: 1.0e-6 + weight_decay: 0 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 8 + warmup_steps: 10 + data_args: + template: qwen2_5 + strategy_args: + strategy_name: megatron_train + strategy_config: + tensor_model_parallel_size: 1 + context_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + use_distributed_optimizer: true + recompute_granularity: full + max_grad_norm: ${max_grad_norm} + device_mapping: list(range(0,8)) + infer_batch_size: 1 + +actor_infer: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + generating_args: + max_new_tokens: 1024 # single-turn response length + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: 1 + logprobs: 1 + data_args: + template: qwen2_5 + strategy_args: + strategy_name: sglang + strategy_config: + mem_fraction_static: 0.85 + load_format: auto + device_mapping: list(range(0,8)) + infer_batch_size: 1 + +reference: + model_args: + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: ~ + data_args: + template: qwen2_5 + strategy_args: + strategy_name: hf_infer + strategy_config: ~ + device_mapping: list(range(0,8)) + infer_batch_size: 1 + +reward_normalization: + grouping: traj_group_id # 可以tags(env_type)/traj_group_id(group)/batch(rollout_batch)... 
group_by计算reward/adv + method: mean_std # asym_clip / identity / mean_std + +train_env_manager: + format_penalty: -0.05 + num_env_groups: 8 + group_size: 8 + max_env_num_per_worker: 1 # The max_env_num_per_worker must be set to 1 to avoid conflicts with the webshop simple server. + tags: [WebShopEnv] + num_groups_partition: [8] # If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation + +val_env_manager: + num_env_groups: 64 + group_size: 1 # should be set to 1 because val temperature is set to 0 and same prompt leads to same output + max_env_num_per_worker: 1 # The max_env_num_per_worker must be set to 1 to avoid conflicts with the webshop simple server. + tags: [WebShopEnv] + num_groups_partition: [64] # TODO: If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation + +custom_envs: + WebShopEnv: + ${custom_env.WebShopEnv} \ No newline at end of file diff --git a/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml new file mode 100644 index 000000000..e2cbac8a7 --- /dev/null +++ b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml @@ -0,0 +1,265 @@ +hydra: + run: + dir: . + output_subdir: null + +exp_name: "qwen2.5-7B-rlvr-infer-correction-config" +seed: 42 +logging_dir: ./output/logs +output_dir: ./output +system_envs: + USE_MODELSCOPE: '1' + + +checkpoint_config: + type: file_system + output_dir: /data/cpfs_0/rl_examples/models/${exp_name} + +track_with: tensorboard +tracker_kwargs: + log_dir: ./rl_examples/llm/tensorboard/roll_exp/rlvr + +num_gpus_per_node: 8 + +max_steps: 500 +save_steps: 100 +logging_steps: 1 +eval_steps: 10 +resume_from_checkpoint: false + + +rollout_batch_size: 64 # prompt +prompt_length: 2048 +response_length: 4096 + +num_return_sequences_in_group: 8 +ppo_epochs: 1 +adv_estimator: "reinforce" + +# clip +value_clip: 0.5 +reward_clip: 10 +advantage_clip: 2.0 +dual_clip_loss: true + +# normalize +norm_mean_type: ~ +norm_std_type: ~ + +# data mask +max_len_mask: true +difficulty_mask: true +difficulty_low_threshold: 0.1 +difficulty_high_threshold: 0.95 +error_max_len_clip: false + +# data weight +difficulty_loss_weight: false +length_loss_weight: false + +# reward +add_token_level_kl: false + +# advantage +whiten_advantages: true + +# dynamic sampling scheduler +# use_additional_prompts: true +# max_running_requests: 256 +# is_num_return_sequences_expand: false + +pretrain: Qwen/Qwen2.5-7B-Instruct +reward_pretrain: Qwen/Qwen2.5-7B-Instruct + + +# infer correction +infer_correction: true + +infer_is_mode: token +infer_is_threshold_min: 0.0 +infer_is_threshold_max: 2.0 # 1.5~5.0 + +enable_token_reject: false +infer_token_rs_threshold_min: 0.0 +infer_token_rs_threshold_max: 2.0 # 2~10 + +enable_catastrophic_reject: false +infer_catastrophic_threshold: 1e-4 + +enable_seq_reject: None + +# enable_seq_reject: sequence +# infer_seq_rs_threshold_min: 0.1 +# infer_seq_rs_threshold_max: 10 + +# enable_seq_reject: geometric +# infer_seq_rs_threshold_min: 0.999 +# infer_seq_rs_threshold_max: 1.001 + +validation: + data_args: + template: qwen2.5 + file_name: + - data/math_benchmarks.jsonl + generating_args: + max_new_tokens: ${response_length} + top_p: 0.6 + top_k: 50 + num_beams: 1 + temperature: 0.6 + num_return_sequences: 1 + + +actor_train: + model_args: + disable_gradient_checkpointing: false + dtype: bf16 + model_type: 
~ + training_args: + learning_rate: 1.0e-6 + weight_decay: 0 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 32 + warmup_steps: 20 + num_train_epochs: 50 + data_args: + template: qwen2.5 + file_name: + - data/code_KodCode_data.jsonl + - data/llm_judge_Multi-subject-RLVR_deal_new.jsonl + - data/math_deepmath_deal.jsonl + - data/general_ifeval_train_deal.jsonl + - data/general_CrossThink-QA_deal.jsonl + domain_interleave_probs: + math_rule: 0.4 + code_sandbox: 0.3 + llm_judge: 0.1 + crossthinkqa: 0.1 + ifeval: 0.1 + dataset_dir: data + messages: messages + interleave_probs: "1.0" + preprocessing_num_workers: 16 + strategy_args: + strategy_name: megatron_train + strategy_config: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + use_distributed_optimizer: true + recompute_granularity: full + device_mapping: list(range(0,16)) + infer_batch_size: 4 + +actor_infer: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + generating_args: + max_new_tokens: ${response_length} + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: ${num_return_sequences_in_group} + logprobs: 1 + data_args: + template: qwen2.5 + strategy_args: + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.8 + block_size: 16 + max_model_len: 8000 + device_mapping: list(range(0,12)) + infer_batch_size: 1 + +reference: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + model_type: ~ + data_args: + template: qwen2.5 + strategy_args: + strategy_name: megatron_infer + strategy_config: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + device_mapping: list(range(0,16)) + infer_batch_size: 4 + +rewards: + crossthinkqa: + worker_cls: roll.pipeline.rlvr.rewards.crossthinkqa_rule_reward_worker.CrossThinkQARuleRewardWorker + reward_type: soft + response_length_penalty_coef: 0.0 + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + tag_included: [crossthinkqa] + world_size: 8 + infer_batch_size: 4 + ifeval: + worker_cls: roll.pipeline.rlvr.rewards.ifeval_rule_reward_worker.GeneralRuleRewardWorker + reward_type: soft + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + tag_included: [ifeval] + world_size: 8 + infer_batch_size: 4 + math_rule: + worker_cls: roll.pipeline.rlvr.rewards.math_rule_reward_worker.MathRuleRewardWorker + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + tag_included: [deepmath_103k, aime] + world_size: 8 + infer_batch_size: 1 + code_sandbox: + use_local: true + worker_cls: roll.pipeline.rlvr.rewards.code_sandbox_reward_worker.CodeSandboxRewardWorker + tag_included: [KodCode] + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + world_size: 8 + infer_batch_size: 1 + llm_judge: + # NOTE: llm as judge 也需要gpu, 不能和actor infer共享gpu + worker_cls: roll.pipeline.rlvr.rewards.llm_judge_reward_worker.LLMJudgeRewardWorker + judge_prompt: Qwen2.5-7B-Instruct-RLVR-prompt + judge_model_type: inference + tag_included: [RLVR] + model_args: + model_name_or_path: virtuoussy/Qwen2.5-7B-Instruct-RLVR + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: trl + generating_args: + max_new_tokens: 100 + top_p: 0.8 + top_k: 50 + num_beams: 1 + temperature: 0.8 + num_return_sequences: 1 + data_args: + template: qwen2.5 + strategy_args: + # strategy_name: hf_infer + # 
strategy_config: null + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.8 + block_size: 16 + max_model_len: 8000 + load_format: auto + device_mapping: list(range(12,16)) + infer_batch_size: 4 \ No newline at end of file diff --git a/examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh b/examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh new file mode 100644 index 000000000..ff1a59402 --- /dev/null +++ b/examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh @@ -0,0 +1,10 @@ +#!/bin/bash +# Run `git submodule update --init --recursive` to init submodules before run this script. +set +x + + +pip install -r third_party/webshop-minimal/requirements.txt --trusted-host mirrors.aliyun.com --index-url https://mirrors.aliyun.com/pypi/simple/ +python -m spacy download en_core_web_sm + +CONFIG_PATH=$(basename $(dirname $0)) +python examples/start_agentic_pipeline.py --config_path $CONFIG_PATH --config_name agentic_webshop_infer_correction diff --git a/examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh b/examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh new file mode 100644 index 000000000..6874f4f5e --- /dev/null +++ b/examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh @@ -0,0 +1,5 @@ +#!/bin/bash +set +x + +CONFIG_PATH=$(basename $(dirname $0)) +python examples/start_rlvr_pipeline.py --config_path $CONFIG_PATH --config_name rlvr_infer_correction_config \ No newline at end of file diff --git a/roll/configs/base_config.py b/roll/configs/base_config.py index 350f69cc5..617f672be 100644 --- a/roll/configs/base_config.py +++ b/roll/configs/base_config.py @@ -382,7 +382,63 @@ class PPOConfig(BaseConfig): field(default="seq-mean-token-mean", metadata={"help": "Loss aggregation mode"}) ) dual_clip_loss: bool = field(default=False, metadata={"help": "Use dual clip loss"}) - + + # trainer&rollout mismatch + infer_correction: bool = field( + default=False, + metadata={"help": "Whether to apply importance sampling correction during inference."} + ) + infer_is_mode: Literal["token", "sequence", "none"] = field( + default="token", + metadata={"help": "IS weighting mode: 'token' (per-token ratio), 'sequence' (per-sequence ratio), 'none' (no IS weighting)."} + ) + # Clipping thresholds (used in IS weighting) + infer_is_threshold_min: float = field( + default=0.0, + metadata={"help": "Minimum threshold for IS weight clipping. 
Recommended 0.0 for unbiased estimation."} + ) + infer_is_threshold_max: float = field( + default=2.0, + metadata={"help": "Maximum threshold for IS weight clipping."} + ) + # Token-level rejection + enable_token_reject: bool = field( + default=False, + metadata={"help": "Enable token-level rejection based on IS ratio thresholds."} + ) + infer_token_mask_threshold_min: float = field( + default=0.0, + metadata={"help": "Minimum IS ratio threshold for token rejection."} + ) + infer_token_mask_threshold_max: float = field( + default=2.0, + metadata={"help": "Maximum IS ratio threshold for token rejection."} + ) + # Catastrophic rejection (reject entire sequence if any token ratio is too small) + enable_catastrophic_reject: bool = field( + default=False, + metadata={"help": "Enable catastrophic rejection: reject entire sequence if any valid token has IS ratio below threshold."} + ) + infer_catastrophic_threshold: float = field( + default=1e-4, + metadata={"help": "Threshold below which a token triggers catastrophic rejection of its sequence."} + ) + # Sequence-level rejection + enable_seq_reject: Optional[Literal["sequence", "geometric",'None']] = field( + default=None, + metadata={"help": "Enable sequence-level rejection: 'sequence' uses sum of log-ratios, 'geometric' uses mean. None disables."} + ) + infer_seq_mask_threshold_min: float = field( + default=0.1, + metadata={"help": "Minimum IS ratio threshold for sequence rejection."} + ) + infer_seq_mask_threshold_max: float = field( + default=10.0, + metadata={"help": "Maximum IS ratio threshold for sequence rejection (typically larger than token-level)."} + ) + + + def __post_init__(self): super().__post_init__() diff --git a/roll/configs/generating_args.py b/roll/configs/generating_args.py index 059aff4a9..cf09e2f21 100644 --- a/roll/configs/generating_args.py +++ b/roll/configs/generating_args.py @@ -58,6 +58,10 @@ class GeneratingArguments: default=None, metadata={"help": "Whether to include the stop strings in output text."}, ) + logprobs: Optional[int] = field( + default=None, + metadata={"help": "Whether return infer log-prob."}, + ) def to_dict(self) -> Dict[str, Any]: args = asdict(self) diff --git a/roll/distributed/scheduler/generate_scheduler.py b/roll/distributed/scheduler/generate_scheduler.py index c49b502c1..ab2630304 100644 --- a/roll/distributed/scheduler/generate_scheduler.py +++ b/roll/distributed/scheduler/generate_scheduler.py @@ -856,6 +856,7 @@ async def generate_one_request(self, data: DataProto): eos_token_id=eos_token_id, pad_token_id=pad_token_id, pad_to_seq_len=data.meta_info.get("pad_to_seq_len", True), + output_logprobs=response_data.meta_info.get("output_logprobs",None), ) request_repeat = data.repeat(repeat_times=len(output_tokens)) output.non_tensor_batch = request_repeat.non_tensor_batch diff --git a/roll/pipeline/agentic/env_manager/step_env_manager.py b/roll/pipeline/agentic/env_manager/step_env_manager.py index 737910394..da2741982 100644 --- a/roll/pipeline/agentic/env_manager/step_env_manager.py +++ b/roll/pipeline/agentic/env_manager/step_env_manager.py @@ -85,6 +85,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): prompt_mask = torch.tensor(prompt_masks, dtype=torch.bool).unsqueeze(0) score_tensor = torch.tensor([0] * len(token_ids), dtype=torch.float).unsqueeze(0) score_tensor[0][-1] = history['reward'] + infer_logprobs = history["infer_logprobs"].flatten().unsqueeze(0) position_ids = attention_mask.cumsum(dim=-1) input_ids = pad_to_length(input_ids, 
length=self.pipeline_config.sequence_length, pad_value=self.tokenizer.pad_token_id) @@ -93,7 +94,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): response_mask = pad_to_length(response_mask, length=self.pipeline_config.sequence_length, pad_value=0) prompt_mask = pad_to_length(prompt_mask, length=self.pipeline_config.sequence_length, pad_value=0) score_tensor = pad_to_length(score_tensor, length=self.pipeline_config.sequence_length, pad_value=0) - + infer_logprobs = pad_to_length(infer_logprobs, length=self.pipeline_config.sequence_length, pad_value=0) samples.append(DataProto( batch=TensorDict( { @@ -103,6 +104,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): "response_mask": response_mask, "prompt_mask": prompt_mask, "scores": score_tensor, + "infer_logprobs": infer_logprobs, }, batch_size=input_ids.shape[0]), non_tensor_batch={ diff --git a/roll/pipeline/agentic/env_manager/traj_env_manager.py b/roll/pipeline/agentic/env_manager/traj_env_manager.py index 050594636..bc65ff11a 100644 --- a/roll/pipeline/agentic/env_manager/traj_env_manager.py +++ b/roll/pipeline/agentic/env_manager/traj_env_manager.py @@ -11,7 +11,7 @@ from omegaconf import DictConfig from tensordict import TensorDict from transformers import PreTrainedTokenizer - +import json from roll.pipeline.agentic.llm_proxy import create_llm_proxy, BaseLLMProxy from roll.pipeline.agentic.env_manager.base_env_manager import RolloutCache, BaseEnvManager from roll.utils.env_action_limiter import get_global_limiter @@ -179,6 +179,7 @@ def step(self, llm_output: DataProto): self.rollout_cache.truncated = True self.rollout_cache.history[-1]['reward'] = reward self.rollout_cache.history[-1]['llm_response'] = responses[0] + self.rollout_cache.history[-1]['infer_logprobs'] = llm_output.batch['infer_logprobs'] if info is not None: self.rollout_cache.history[-1].update(info) @@ -293,13 +294,16 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): token_ids = [] prompt_masks = [] response_masks = [] + infer_logprobs = [] for items in self.rollout_cache.history: token_ids.extend(items["prompt_ids"]) token_ids.extend(items["response_ids"]) prompt_masks.extend([1] * len(items["prompt_ids"]) + [0] * len(items["response_ids"])) response_masks.extend([0] * len(items["prompt_ids"]) + [1] * len(items["response_ids"])) - + infer_logprobs.extend(items["infer_logprobs"].flatten().tolist()) + input_ids =torch.tensor(token_ids, dtype=torch.long).unsqueeze(0) + infer_logprobs = torch.tensor(infer_logprobs, dtype=torch.float).unsqueeze(0) attention_mask = torch.tensor([1] * len(token_ids), dtype=torch.long).unsqueeze(0) response_mask = torch.tensor(response_masks, dtype=torch.bool).unsqueeze(0) @@ -316,6 +320,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): "input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids, + "infer_logprobs": infer_logprobs, }, batch_size=input_ids.shape[0]) @@ -323,6 +328,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): # TODO: move pad to pipeline input_ids = pad_to_length(input_ids, length=self.pipeline_config.sequence_length, pad_value=self.tokenizer.pad_token_id) + infer_logprobs = pad_to_length(infer_logprobs, length=self.pipeline_config.sequence_length, pad_value=0) attention_mask = pad_to_length(attention_mask, length=self.pipeline_config.sequence_length, pad_value=0) position_ids = pad_to_length(position_ids, length=self.pipeline_config.sequence_length, pad_value=0) response_mask = pad_to_length(response_mask, 
length=self.pipeline_config.sequence_length, pad_value=0) @@ -336,6 +342,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): "response_mask": response_mask, "prompt_mask": prompt_mask, "scores": score_tensor, + "infer_logprobs": infer_logprobs, }) lm_input.non_tensor_batch.update({ "env_ids": np.array([self.rollout_cache.env_id], dtype=object), diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py index 0987cfefa..f6e594fe7 100644 --- a/roll/pipeline/base_worker.py +++ b/roll/pipeline/base_worker.py @@ -2,7 +2,7 @@ import threading import time from typing import Union, Optional, Dict - +import numpy import ray import torch from codetiming import Timer @@ -25,6 +25,7 @@ postprocess_generate, GenerateRequestType, agg_loss, + masked_sum ) from roll.utils.offload_states import OffloadStateType from roll.utils.dynamic_batching import make_mini_batch_iter_for_dynamic_batching @@ -264,9 +265,11 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ response_mask = data.batch["response_mask"][:, 1:].long() + final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] old_log_probs = data.batch["old_log_probs"] advantages = data.batch["advantages"] + infer_log_probs = data.batch.get("infer_logprobs", None) log_probs = self.strategy.op_compute_log_probs( logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"] @@ -282,12 +285,18 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-pg_loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) pg_loss = torch.where(advantages < 0, dual_clip_loss, pg_loss) + + if infer_log_probs is not None and self.pipeline_config.infer_correction: + pg_loss, infer_response_mask, infer_stats=self.infer_correction( + old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, + response_mask=response_mask,pg_loss=pg_loss) + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() + + pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=final_response_mask, - pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) - kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=response_mask, - kl_penalty="k3") - kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) approxkl = compute_approx_kl( log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="mse" @@ -329,9 +338,153 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): "actor/policykl": agg_loss(loss_mat=policykl, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(), } + pg_metrics.update(infer_stats) return total_loss, pg_metrics - + + + def infer_correction( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor, + pg_loss: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, dict]: + """ + 处理 importance sampling ratio,支持 IS 裁剪与多种 reject 策略。 + 返回更新后的 pg_loss、mask 和详细统计信息。 + """ + # Step 0: Shape alignment 
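+        # infer_logprobs is stored at the full sequence length, while old_log_probs and
+        # the shifted response_mask ([:, 1:]) are one position shorter, so the leading
+        # position is dropped below to align the three tensors.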
+ if infer_log_probs.shape[1] == old_log_probs.shape[1]+1: + infer_log_probs = infer_log_probs[:, 1:] # align with response_mask[:, 1:] + assert old_log_probs.shape == infer_log_probs.shape == response_mask.shape, \ + f"Shape mismatch: {old_log_probs.shape}, {infer_log_probs.shape}, {response_mask.shape}" + # Step 1: Compute log-ratio and ratio + log_ratio = old_log_probs - infer_log_probs # [B, T] + ratio = torch.exp(log_ratio) # [B, T] + # Step 2: Apply IS weighting strategy (optional) + if self.pipeline_config.infer_is_mode == "token": + raw_is_weight = ratio + elif self.pipeline_config.infer_is_mode == "sequence": + log_ratio_sum = masked_sum(log_ratio, response_mask, dim=-1).unsqueeze(-1) # [B, 1] + raw_is_weight = torch.exp(log_ratio_sum).expand_as(ratio) # [B, T] + elif self.pipeline_config.infer_is_mode in (None, "none", ""): + raw_is_weight = torch.ones_like(ratio) + else: + raw_is_weight = torch.ones_like(ratio) + # Clamp to get final is_weight (used for loss) + is_weight = raw_is_weight.clamp( + min=self.pipeline_config.infer_is_threshold_min, + max=self.pipeline_config.infer_is_threshold_max + ).detach() + # Step 3: Build rejection mask + original_valid = response_mask > 0.5 # [B, T], bool + keep_mask = original_valid.clone() + # (a) Token-level ratio reject + if getattr(self.pipeline_config, 'enable_token_reject', False): + ratio_too_high = ratio > self.pipeline_config.infer_token_mask_threshold_max + ratio_too_low = ratio < self.pipeline_config.infer_token_mask_threshold_min + token_reject = ratio_too_high | ratio_too_low + keep_mask = keep_mask & (~token_reject) + # (b) Catastrophic reject + if getattr(self.pipeline_config, 'enable_catastrophic_reject', False): + catastrophic = (ratio < self.pipeline_config.infer_catastrophic_threshold) & original_valid + has_catastrophic = catastrophic.any(dim=-1, keepdim=True) + keep_mask = keep_mask & (~has_catastrophic) + # (c) Sequence-level reject + if getattr(self.pipeline_config, 'enable_seq_reject', False): + if self.pipeline_config.enable_seq_reject=="sequence": + log_ratio_sum = masked_sum(log_ratio, response_mask, dim=-1) # [B] + seq_ratio = torch.exp(log_ratio_sum) # [B] + seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max + seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min + seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) + keep_mask = keep_mask & (~seq_reject) + elif self.pipeline_config.enable_seq_reject=="geometric": + log_ratio_mean = masked_mean(log_ratio, response_mask, dim=-1) # [B] + seq_ratio = torch.exp(log_ratio_mean) # [B] + seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max + seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min + seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) + keep_mask = keep_mask & (~seq_reject) + # final_mask = keep_mask.float() + final_mask = keep_mask + # Step 4: Reweight policy loss + pg_loss = pg_loss * is_weight + # Step 5: Compute detailed stats over original_valid tokens + # Rejected mask + rejected_mask = original_valid & (~keep_mask) # [B, T] + # Clipped mask: only meaningful if IS weighting is active + if self.pipeline_config.infer_is_mode in ("token", "sequence"): + clipped_low = (raw_is_weight <= self.pipeline_config.infer_is_threshold_min) & original_valid + clipped_high = (raw_is_weight >= self.pipeline_config.infer_is_threshold_max) & original_valid + clipped_mask = clipped_low | clipped_high # [B, T] + else: + clipped_mask = torch.zeros_like(original_valid) # no 
clipping + # Compute fractions + def _compute_frac(mask_tensor): + return agg_loss( + loss_mat=mask_tensor.float(), + loss_mask=response_mask, + loss_agg_mode="token-mean" # force token-wise average + ).detach().item() + clip_frac = _compute_frac(clipped_mask) + reject_frac = _compute_frac(rejected_mask) + clip_and_reject_frac = _compute_frac(clipped_mask & rejected_mask) + clip_or_reject_frac = _compute_frac(clipped_mask | rejected_mask) + # A sequence is rejected if NO token is kept (i.e., all final_mask == 0 for that seq) + seq_has_valid = original_valid.any(dim=-1) # [B], bool: seq has >=1 valid token + seq_completely_rejected = (~keep_mask).all(dim=-1) & seq_has_valid # [B] + total_valid_seqs = seq_has_valid.sum().item() + rejected_seqs = seq_completely_rejected.sum().item() + seq_reject_frac = rejected_seqs / total_valid_seqs if total_valid_seqs > 0 else 0.0 + + ### kl metric + inferkl_orig = compute_approx_kl( + log_probs=infer_log_probs, + log_probs_base=old_log_probs, + action_mask=response_mask, # ← original mask + kl_penalty="kl" + ) + inferkl_final = compute_approx_kl( + log_probs=infer_log_probs, + log_probs_base=old_log_probs, + action_mask=final_mask, # ← after rejection + kl_penalty="kl" + ) + inferkl_orig_agg = agg_loss( + loss_mat=inferkl_orig, + loss_mask=response_mask, + loss_agg_mode=self.pipeline_config.loss_agg_mode + ).detach().item() + inferkl_final_agg = agg_loss( + loss_mat=inferkl_final, + loss_mask=final_mask, + loss_agg_mode=self.pipeline_config.loss_agg_mode + ).detach().item() + valid_raw_is_weight = raw_is_weight[original_valid] # [N_valid_tokens,] + if valid_raw_is_weight.numel() > 0: + raw_is_mean = valid_raw_is_weight.mean().detach().item() + raw_is_std = valid_raw_is_weight.std(unbiased=False).detach().item() + raw_is_min = valid_raw_is_weight.min().detach().item() + raw_is_max = valid_raw_is_weight.max().detach().item() + else: + # fallback if no valid tokens (rare edge case) + raw_is_mean = raw_is_std = raw_is_min = raw_is_max = 0.0 + stats = { + "infer_correction/reject_frac": reject_frac, + "infer_correction/clip_frac": clip_frac, + "infer_correction/clip_and_reject_frac": clip_and_reject_frac, + "infer_correction/clip_or_reject_frac": clip_or_reject_frac, + "infer_correction/seq_reject_frac": seq_reject_frac, + "infer_correction/inferkl_orig": inferkl_orig_agg, + "infer_correction/inferkl_final": inferkl_final_agg, + "infer_correction/raw_is_mean": raw_is_mean, + "infer_correction/raw_is_std": raw_is_std, + "infer_correction/raw_is_min": raw_is_min, + "infer_correction/raw_is_max": raw_is_max, + } + return pg_loss, final_mask, stats @register(dispatch_mode=Dispatch.ONE_TO_ALL) def do_checkpoint(self, global_step): with Timer("do_checkpoint") as total_timer: diff --git a/roll/pipeline/rlvr/actor_pg_worker.py b/roll/pipeline/rlvr/actor_pg_worker.py index 813e582f5..b518acc2e 100644 --- a/roll/pipeline/rlvr/actor_pg_worker.py +++ b/roll/pipeline/rlvr/actor_pg_worker.py @@ -30,6 +30,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ response_mask = data.batch["response_mask"][:, 1:].long() + infer_log_probs = data.batch.get("infer_logprobs", None) final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] @@ -76,7 +77,13 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): pg_loss = self._compute_kimi15_loss(ratio, log_probs, old_log_probs, advantages) else: raise ValueError(f"Unsupported pg_variant: {pg_variant}") - + + if infer_log_probs is not None 
and self.pipeline_config.infer_correction: + loss, infer_response_mask, infer_stats=self.infer_correction( + old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, + response_mask=response_mask,pg_loss=loss) + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() + weighted_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights) original_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, @@ -127,6 +134,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): # 使用缓存的指标 pg_metrics = self._get_pg_metrics(data) + pg_metrics.updata(infer_stats) return total_loss, pg_metrics diff --git a/roll/pipeline/rlvr/actor_worker.py b/roll/pipeline/rlvr/actor_worker.py index fbd7147bd..603ab4d7b 100644 --- a/roll/pipeline/rlvr/actor_worker.py +++ b/roll/pipeline/rlvr/actor_worker.py @@ -19,9 +19,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): ref_log_probs = data.batch["ref_log_probs"] old_log_probs = data.batch["old_log_probs"] - infer_log_probs = data.batch.get("infer_logprobs", old_log_probs) - infer_log_probs = infer_log_probs if len(infer_log_probs) > 0 else old_log_probs - + infer_log_probs = data.batch.get("infer_logprobs", None) + advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( @@ -57,8 +56,6 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.importance_sampling == "token": ratio = (log_probs - old_log_probs).exp() - train_infer_ratio = (log_probs - infer_log_probs).exp() - train_infer_diff = log_probs.exp() - infer_log_probs.exp() elif self.pipeline_config.importance_sampling == "seq": log_ratio = log_probs - old_log_probs masked_log_ratio = masked_mean(log_ratio, final_response_mask, dim=-1) @@ -73,7 +70,13 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) loss = torch.where(advantages < 0, dual_clip_loss, loss) - + + if infer_log_probs is not None and self.pipeline_config.infer_correction: + loss, infer_response_mask, infer_stats=self.infer_correction( + old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, + response_mask=response_mask,pg_loss=loss) + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() + weighted_pg_loss = agg_loss(loss_mat=loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights, loss_scale=loss_scale) @@ -121,11 +124,6 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): total_loss = total_loss + topr_neg_loss * self.pipeline_config.use_topr_neg_loss_coef metrics['actor/topr_neg_loss'] = topr_neg_loss.detach().item() - train_infer_prob_metric = { - "actor/train_infer_ratio_mean": masked_mean(train_infer_ratio, response_mask, dim=-1).mean().detach().item(), - "actor/train_infer_diff_mean": masked_mean(train_infer_diff, response_mask, dim=-1).mean().detach().item(), - } - loss_metric = { "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(), "actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(), @@ -154,9 +152,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): "actor/sample_weights_max": sample_weights.max().detach().item(), **metrics, **loss_metric, - **train_infer_prob_metric } - + pg_metrics.update(infer_stats) return 
total_loss, pg_metrics def compute_sample_weights(self, data: DataProto, response_mask: torch.Tensor): From 3d291b19143e88295f5ff88a46b8a63ee14b3571 Mon Sep 17 00:00:00 2001 From: WeepCat Date: Thu, 4 Dec 2025 17:21:19 +0800 Subject: [PATCH 02/58] fixed_llm_proxy_mode_rollout_pipeline --- roll/pipeline/agentic/agentic_rollout_pipeline.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/roll/pipeline/agentic/agentic_rollout_pipeline.py b/roll/pipeline/agentic/agentic_rollout_pipeline.py index 4d4309ea8..d2c875923 100644 --- a/roll/pipeline/agentic/agentic_rollout_pipeline.py +++ b/roll/pipeline/agentic/agentic_rollout_pipeline.py @@ -63,8 +63,9 @@ def run(self): batch.meta_info = {"global_step": global_step} with Timer(name="rollout", logger=None) as rollout_timer: - batch.meta_info["is_offload_states"] = True - self.actor_infer.start_server(data=batch) + if self.use_policy_model: + batch.meta_info["is_offload_states"] = True + self.actor_infer.start_server(data=batch) batch = ray.get(self.rollout_scheduler.get_batch.remote(batch, self.pipeline_config.rollout_batch_size)) if batch is None: break From 6ca3d1041019b7fa769b523b9cb321a81142df78 Mon Sep 17 00:00:00 2001 From: WeepCat Date: Thu, 4 Dec 2025 17:02:12 +0800 Subject: [PATCH 03/58] fixed some typos --- .../docs/User Guides/Agentic/agentic_engineer_practice.md | 4 ++-- .../User Guides/Agentic/agentic_engineer_practice.md | 4 ++-- roll/pipeline/agentic/utils.py | 6 +++--- roll/pipeline/rlvr/utils.py | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs_roll/docs/User Guides/Agentic/agentic_engineer_practice.md b/docs_roll/docs/User Guides/Agentic/agentic_engineer_practice.md index 7277219a6..24810c2ac 100644 --- a/docs_roll/docs/User Guides/Agentic/agentic_engineer_practice.md +++ b/docs_roll/docs/User Guides/Agentic/agentic_engineer_practice.md @@ -233,7 +233,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): lm_input.non_tensor_batch["episode_score"] = np.array([episode_score], dtype=object) # Configure database field types - colummns_config = [ + columns_config = [ ["task_idx", "bigint"], ["model_name", "string"], ["stop_reason", "string"], @@ -241,7 +241,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): ["mode", "string"], ["save_content", "string"], ] - lm_input.meta_info["COLUMMNS_CONFIG"] = colummns_config + lm_input.meta_info["COLUMNS_CONFIG"] = columns_config return lm_input ``` diff --git a/docs_roll/i18n/zh-Hans/docusaurus-plugin-content-docs/current/User Guides/Agentic/agentic_engineer_practice.md b/docs_roll/i18n/zh-Hans/docusaurus-plugin-content-docs/current/User Guides/Agentic/agentic_engineer_practice.md index 40ad837c0..8f1a06143 100644 --- a/docs_roll/i18n/zh-Hans/docusaurus-plugin-content-docs/current/User Guides/Agentic/agentic_engineer_practice.md +++ b/docs_roll/i18n/zh-Hans/docusaurus-plugin-content-docs/current/User Guides/Agentic/agentic_engineer_practice.md @@ -234,7 +234,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): lm_input.non_tensor_batch["episode_score"] = np.array([episode_score], dtype=object) # 配置数据库字段类型 - colummns_config = [ + columns_config = [ ["task_idx", "bigint"], ["model_name", "string"], ["stop_reason", "string"], @@ -242,7 +242,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): ["mode", "string"], ["save_content", "string"], ] - lm_input.meta_info["COLUMMNS_CONFIG"] = colummns_config + lm_input.meta_info["COLUMNS_CONFIG"] = columns_config return lm_input ``` diff --git 
a/roll/pipeline/agentic/utils.py b/roll/pipeline/agentic/utils.py index 2f4f2b052..14f573d2e 100644 --- a/roll/pipeline/agentic/utils.py +++ b/roll/pipeline/agentic/utils.py @@ -204,13 +204,13 @@ def dump_rollout_trajectories(path, global_step, data: DataProto): """ Dumps rollout trajectories to persistent storage. - The data is written using a column-based configuration defined in COLUMMNS_CONFIG. + The data is written using a column-based configuration defined in COLUMNS_CONFIG. Each column is specified as a list [column_name, data_type], where: - column_name: string identifier for the column - data_type: data type specification ('bigint', 'string', 'double', etc.) Example configuration: - colummns_config = [ + columns_config = [ ['global_step', 'bigint'], ['id', 'string'], ['source', 'string'], @@ -220,7 +220,7 @@ def dump_rollout_trajectories(path, global_step, data: DataProto): if not path: return - columns_config: Optional[List] = data.meta_info.get("COLUMMNS_CONFIG", None) + columns_config: Optional[List] = data.meta_info.get("COLUMNS_CONFIG", None) if columns_config is None: return diff --git a/roll/pipeline/rlvr/utils.py b/roll/pipeline/rlvr/utils.py index 36a55466b..7173a1118 100644 --- a/roll/pipeline/rlvr/utils.py +++ b/roll/pipeline/rlvr/utils.py @@ -14,7 +14,7 @@ logger = get_logger() -COLUMMNS_CONFIG = [ +COLUMNS_CONFIG = [ ['global_step','bigint'], ['id','string'], ['source','string'], From 77c1e5d99c897e898ea1d9dc08a3b1e347f3679f Mon Sep 17 00:00:00 2001 From: millioniron Date: Sun, 7 Dec 2025 14:43:05 +0800 Subject: [PATCH 04/58] =?UTF-8?q?=E9=87=8D=E6=96=B0=E4=BF=AE=E8=AE=A2?= =?UTF-8?q?=E4=BA=86=E6=95=B4=E4=B8=AA=E7=9A=84=E6=8E=92=E7=89=88=EF=BC=8C?= =?UTF-8?q?=E6=8A=BD=E8=B1=A1=E5=87=BA=E4=BA=86=E4=B8=80=E4=B8=AA=E7=B1=BB?= =?UTF-8?q?=EF=BC=8C=E4=BD=BF=E5=BE=97=E5=8F=AF=E4=BB=A5=E6=9B=B4=E5=8A=A0?= =?UTF-8?q?=E8=87=AA=E7=94=B1=E5=AE=9A=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../rlvr_infer_correction_config.yaml | 12 +- roll/configs/base_config.py | 6 +- roll/configs/generating_args.py | 2 +- roll/pipeline/base_worker.py | 169 +------- roll/pipeline/rlvr/actor_pg_worker.py | 22 +- roll/pipeline/rlvr/actor_worker.py | 22 +- roll/utils/infer_correction.py | 380 ++++++++++++++++++ 7 files changed, 438 insertions(+), 175 deletions(-) create mode 100644 roll/utils/infer_correction.py diff --git a/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml index e2cbac8a7..1bbe3c8e8 100644 --- a/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml +++ b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml @@ -80,8 +80,8 @@ infer_is_threshold_min: 0.0 infer_is_threshold_max: 2.0 # 1.5~5.0 enable_token_reject: false -infer_token_rs_threshold_min: 0.0 -infer_token_rs_threshold_max: 2.0 # 2~10 +infer_token_mask_threshold_min: 0.0 +infer_token_mask_threshold_max: 2.0 # 2~10 enable_catastrophic_reject: false infer_catastrophic_threshold: 1e-4 @@ -89,12 +89,12 @@ infer_catastrophic_threshold: 1e-4 enable_seq_reject: None # enable_seq_reject: sequence -# infer_seq_rs_threshold_min: 0.1 -# infer_seq_rs_threshold_max: 10 +# infer_seq_mask_threshold_min: 0.1 +# infer_seq_mask_threshold_max: 10 # enable_seq_reject: geometric -# infer_seq_rs_threshold_min: 0.999 -# infer_seq_rs_threshold_max: 1.001 +# infer_seq_mask_threshold_min: 0.999 +# infer_seq_mask_threshold_max: 1.001 validation: data_args: diff --git 
a/roll/configs/base_config.py b/roll/configs/base_config.py index 617f672be..b2ef1ff28 100644 --- a/roll/configs/base_config.py +++ b/roll/configs/base_config.py @@ -388,9 +388,9 @@ class PPOConfig(BaseConfig): default=False, metadata={"help": "Whether to apply importance sampling correction during inference."} ) - infer_is_mode: Literal["token", "sequence", "none"] = field( - default="token", - metadata={"help": "IS weighting mode: 'token' (per-token ratio), 'sequence' (per-sequence ratio), 'none' (no IS weighting)."} + infer_is_mode: Literal["token", "sequence", "None"] = field( + default="None", + metadata={"help": "IS weighting mode: 'token' (per-token ratio), 'sequence' (per-sequence ratio), 'None' (no IS weighting)."} ) # Clipping thresholds (used in IS weighting) infer_is_threshold_min: float = field( diff --git a/roll/configs/generating_args.py b/roll/configs/generating_args.py index cf09e2f21..8a540eb57 100644 --- a/roll/configs/generating_args.py +++ b/roll/configs/generating_args.py @@ -59,7 +59,7 @@ class GeneratingArguments: metadata={"help": "Whether to include the stop strings in output text."}, ) logprobs: Optional[int] = field( - default=None, + default=0, metadata={"help": "Whether return infer log-prob."}, ) diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py index f6e594fe7..62854c8b0 100644 --- a/roll/pipeline/base_worker.py +++ b/roll/pipeline/base_worker.py @@ -25,9 +25,10 @@ postprocess_generate, GenerateRequestType, agg_loss, - masked_sum + masked_sum, ) from roll.utils.offload_states import OffloadStateType +from roll.utils.infer_correction import InferCorrectionHandler from roll.utils.dynamic_batching import make_mini_batch_iter_for_dynamic_batching from roll.platforms import current_platform @@ -266,11 +267,11 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): response_mask = data.batch["response_mask"][:, 1:].long() final_response_mask = data.batch.get("final_response_mask", response_mask) + ref_log_probs = data.batch["ref_log_probs"] old_log_probs = data.batch["old_log_probs"] advantages = data.batch["advantages"] infer_log_probs = data.batch.get("infer_logprobs", None) - log_probs = self.strategy.op_compute_log_probs( logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"] ) @@ -285,19 +286,26 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-pg_loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) pg_loss = torch.where(advantages < 0, dual_clip_loss, pg_loss) - + + infer_stats = {} if infer_log_probs is not None and self.pipeline_config.infer_correction: - pg_loss, infer_response_mask, infer_stats=self.infer_correction( - old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, - response_mask=response_mask,pg_loss=pg_loss) + correction_handler = InferCorrectionHandler(self.pipeline_config) + + pg_loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=pg_loss + ) + # 更新最终掩码 final_response_mask = (final_response_mask.bool() & infer_response_mask).long() pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) - kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=final_response_mask, + kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, 
action_mask=final_response_mask, + kl_penalty="k3") kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) - approxkl = compute_approx_kl( log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="mse" ) @@ -341,150 +349,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): pg_metrics.update(infer_stats) return total_loss, pg_metrics - - - def infer_correction( - self, - old_log_probs: torch.Tensor, - infer_log_probs: torch.Tensor, - response_mask: torch.Tensor, - pg_loss: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, dict]: - """ - 处理 importance sampling ratio,支持 IS 裁剪与多种 reject 策略。 - 返回更新后的 pg_loss、mask 和详细统计信息。 - """ - # Step 0: Shape alignment - if infer_log_probs.shape[1] == old_log_probs.shape[1]+1: - infer_log_probs = infer_log_probs[:, 1:] # align with response_mask[:, 1:] - assert old_log_probs.shape == infer_log_probs.shape == response_mask.shape, \ - f"Shape mismatch: {old_log_probs.shape}, {infer_log_probs.shape}, {response_mask.shape}" - # Step 1: Compute log-ratio and ratio - log_ratio = old_log_probs - infer_log_probs # [B, T] - ratio = torch.exp(log_ratio) # [B, T] - # Step 2: Apply IS weighting strategy (optional) - if self.pipeline_config.infer_is_mode == "token": - raw_is_weight = ratio - elif self.pipeline_config.infer_is_mode == "sequence": - log_ratio_sum = masked_sum(log_ratio, response_mask, dim=-1).unsqueeze(-1) # [B, 1] - raw_is_weight = torch.exp(log_ratio_sum).expand_as(ratio) # [B, T] - elif self.pipeline_config.infer_is_mode in (None, "none", ""): - raw_is_weight = torch.ones_like(ratio) - else: - raw_is_weight = torch.ones_like(ratio) - # Clamp to get final is_weight (used for loss) - is_weight = raw_is_weight.clamp( - min=self.pipeline_config.infer_is_threshold_min, - max=self.pipeline_config.infer_is_threshold_max - ).detach() - # Step 3: Build rejection mask - original_valid = response_mask > 0.5 # [B, T], bool - keep_mask = original_valid.clone() - # (a) Token-level ratio reject - if getattr(self.pipeline_config, 'enable_token_reject', False): - ratio_too_high = ratio > self.pipeline_config.infer_token_mask_threshold_max - ratio_too_low = ratio < self.pipeline_config.infer_token_mask_threshold_min - token_reject = ratio_too_high | ratio_too_low - keep_mask = keep_mask & (~token_reject) - # (b) Catastrophic reject - if getattr(self.pipeline_config, 'enable_catastrophic_reject', False): - catastrophic = (ratio < self.pipeline_config.infer_catastrophic_threshold) & original_valid - has_catastrophic = catastrophic.any(dim=-1, keepdim=True) - keep_mask = keep_mask & (~has_catastrophic) - # (c) Sequence-level reject - if getattr(self.pipeline_config, 'enable_seq_reject', False): - if self.pipeline_config.enable_seq_reject=="sequence": - log_ratio_sum = masked_sum(log_ratio, response_mask, dim=-1) # [B] - seq_ratio = torch.exp(log_ratio_sum) # [B] - seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max - seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min - seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) - keep_mask = keep_mask & (~seq_reject) - elif self.pipeline_config.enable_seq_reject=="geometric": - log_ratio_mean = masked_mean(log_ratio, response_mask, dim=-1) # [B] - seq_ratio = torch.exp(log_ratio_mean) # [B] - seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max - seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min - seq_reject = 
(seq_too_high | seq_too_low).unsqueeze(-1) - keep_mask = keep_mask & (~seq_reject) - # final_mask = keep_mask.float() - final_mask = keep_mask - # Step 4: Reweight policy loss - pg_loss = pg_loss * is_weight - # Step 5: Compute detailed stats over original_valid tokens - # Rejected mask - rejected_mask = original_valid & (~keep_mask) # [B, T] - # Clipped mask: only meaningful if IS weighting is active - if self.pipeline_config.infer_is_mode in ("token", "sequence"): - clipped_low = (raw_is_weight <= self.pipeline_config.infer_is_threshold_min) & original_valid - clipped_high = (raw_is_weight >= self.pipeline_config.infer_is_threshold_max) & original_valid - clipped_mask = clipped_low | clipped_high # [B, T] - else: - clipped_mask = torch.zeros_like(original_valid) # no clipping - # Compute fractions - def _compute_frac(mask_tensor): - return agg_loss( - loss_mat=mask_tensor.float(), - loss_mask=response_mask, - loss_agg_mode="token-mean" # force token-wise average - ).detach().item() - clip_frac = _compute_frac(clipped_mask) - reject_frac = _compute_frac(rejected_mask) - clip_and_reject_frac = _compute_frac(clipped_mask & rejected_mask) - clip_or_reject_frac = _compute_frac(clipped_mask | rejected_mask) - # A sequence is rejected if NO token is kept (i.e., all final_mask == 0 for that seq) - seq_has_valid = original_valid.any(dim=-1) # [B], bool: seq has >=1 valid token - seq_completely_rejected = (~keep_mask).all(dim=-1) & seq_has_valid # [B] - total_valid_seqs = seq_has_valid.sum().item() - rejected_seqs = seq_completely_rejected.sum().item() - seq_reject_frac = rejected_seqs / total_valid_seqs if total_valid_seqs > 0 else 0.0 - - ### kl metric - inferkl_orig = compute_approx_kl( - log_probs=infer_log_probs, - log_probs_base=old_log_probs, - action_mask=response_mask, # ← original mask - kl_penalty="kl" - ) - inferkl_final = compute_approx_kl( - log_probs=infer_log_probs, - log_probs_base=old_log_probs, - action_mask=final_mask, # ← after rejection - kl_penalty="kl" - ) - inferkl_orig_agg = agg_loss( - loss_mat=inferkl_orig, - loss_mask=response_mask, - loss_agg_mode=self.pipeline_config.loss_agg_mode - ).detach().item() - inferkl_final_agg = agg_loss( - loss_mat=inferkl_final, - loss_mask=final_mask, - loss_agg_mode=self.pipeline_config.loss_agg_mode - ).detach().item() - valid_raw_is_weight = raw_is_weight[original_valid] # [N_valid_tokens,] - if valid_raw_is_weight.numel() > 0: - raw_is_mean = valid_raw_is_weight.mean().detach().item() - raw_is_std = valid_raw_is_weight.std(unbiased=False).detach().item() - raw_is_min = valid_raw_is_weight.min().detach().item() - raw_is_max = valid_raw_is_weight.max().detach().item() - else: - # fallback if no valid tokens (rare edge case) - raw_is_mean = raw_is_std = raw_is_min = raw_is_max = 0.0 - stats = { - "infer_correction/reject_frac": reject_frac, - "infer_correction/clip_frac": clip_frac, - "infer_correction/clip_and_reject_frac": clip_and_reject_frac, - "infer_correction/clip_or_reject_frac": clip_or_reject_frac, - "infer_correction/seq_reject_frac": seq_reject_frac, - "infer_correction/inferkl_orig": inferkl_orig_agg, - "infer_correction/inferkl_final": inferkl_final_agg, - "infer_correction/raw_is_mean": raw_is_mean, - "infer_correction/raw_is_std": raw_is_std, - "infer_correction/raw_is_min": raw_is_min, - "infer_correction/raw_is_max": raw_is_max, - } - return pg_loss, final_mask, stats + @register(dispatch_mode=Dispatch.ONE_TO_ALL) def do_checkpoint(self, global_step): with Timer("do_checkpoint") as total_timer: diff --git 
a/roll/pipeline/rlvr/actor_pg_worker.py b/roll/pipeline/rlvr/actor_pg_worker.py index b518acc2e..dd651b414 100644 --- a/roll/pipeline/rlvr/actor_pg_worker.py +++ b/roll/pipeline/rlvr/actor_pg_worker.py @@ -4,6 +4,7 @@ from roll.distributed.scheduler.protocol import DataProto from roll.utils.functionals import masked_mean, agg_loss, compute_approx_kl from roll.pipeline.rlvr.actor_worker import ActorWorker +from roll.utils.infer_correction import InferCorrectionHandler class ActorPGWorker(ActorWorker): @@ -30,11 +31,11 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ response_mask = data.batch["response_mask"][:, 1:].long() - infer_log_probs = data.batch.get("infer_logprobs", None) final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] old_log_probs = data.batch["old_log_probs"] + infer_log_probs = data.batch.get("infer_logprobs", None) advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( @@ -77,13 +78,19 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): pg_loss = self._compute_kimi15_loss(ratio, log_probs, old_log_probs, advantages) else: raise ValueError(f"Unsupported pg_variant: {pg_variant}") - + + infer_stats = {} if infer_log_probs is not None and self.pipeline_config.infer_correction: - loss, infer_response_mask, infer_stats=self.infer_correction( - old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, - response_mask=response_mask,pg_loss=loss) + correction_handler = InferCorrectionHandler(self.pipeline_config) + pg_loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=pg_loss + ) + # 更新最终掩码 final_response_mask = (final_response_mask.bool() & infer_response_mask).long() - + weighted_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights) original_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, @@ -134,7 +141,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): # 使用缓存的指标 pg_metrics = self._get_pg_metrics(data) - pg_metrics.updata(infer_stats) + + pg_metrics.update(infer_stats) return total_loss, pg_metrics diff --git a/roll/pipeline/rlvr/actor_worker.py b/roll/pipeline/rlvr/actor_worker.py index 603ab4d7b..48e2fcf32 100644 --- a/roll/pipeline/rlvr/actor_worker.py +++ b/roll/pipeline/rlvr/actor_worker.py @@ -4,6 +4,7 @@ from roll.distributed.scheduler.protocol import DataProto from roll.pipeline.base_worker import ActorWorker as BaseActorWorker from roll.utils.functionals import masked_mean, agg_loss, compute_approx_kl +from roll.utils.infer_correction import InferCorrectionHandler class ActorWorker(BaseActorWorker): @@ -20,7 +21,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): ref_log_probs = data.batch["ref_log_probs"] old_log_probs = data.batch["old_log_probs"] infer_log_probs = data.batch.get("infer_logprobs", None) - + advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( @@ -70,13 +71,21 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) loss = torch.where(advantages < 0, dual_clip_loss, loss) - + + infer_stats = {} if infer_log_probs is not None and self.pipeline_config.infer_correction: - loss, 
infer_response_mask, infer_stats=self.infer_correction( - old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, - response_mask=response_mask,pg_loss=loss) + correction_handler = InferCorrectionHandler(self.pipeline_config) + loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=loss + ) + # 更新最终掩码 final_response_mask = (final_response_mask.bool() & infer_response_mask).long() - + + + weighted_pg_loss = agg_loss(loss_mat=loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights, loss_scale=loss_scale) @@ -124,6 +133,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): total_loss = total_loss + topr_neg_loss * self.pipeline_config.use_topr_neg_loss_coef metrics['actor/topr_neg_loss'] = topr_neg_loss.detach().item() + loss_metric = { "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(), "actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(), diff --git a/roll/utils/infer_correction.py b/roll/utils/infer_correction.py new file mode 100644 index 000000000..215203081 --- /dev/null +++ b/roll/utils/infer_correction.py @@ -0,0 +1,380 @@ +from typing import Literal, Optional, Tuple, Dict, Any +import torch + +class StatsCollector: + """统一收集诊断指标的类""" + def __init__(self, prefix: str = "infer_correction"): + self.prefix = prefix + self.stats: Dict[str, Any] = {} + self.tensor_stats: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {} + + def add(self, name: str, value: Any): + """添加标量指标""" + self.stats[f"{self.prefix}/{name}"] = value.item() if torch.is_tensor(value) else value + + def add_tensor_stat(self, name: str, tensor: torch.Tensor, mask: torch.Tensor): + """添加张量统计指标(延迟计算)""" + self.tensor_stats[name] = (tensor, mask) + + def compute_tensor_stats(self): + """严格遵循原始代码的数据移动策略""" + for name, (tensor, mask) in self.tensor_stats.items(): + # 1. 确保在同一设备上 + if tensor.device != mask.device: + mask = mask.to(tensor.device) + + # 2. 直接在原始代码风格中计算:先筛选,再移动到CPU + mask=mask.bool() + valid = tensor[mask] + + # 3. 严格按照原始代码逻辑处理 + if valid.numel() > 0: + # 关键:先detach()再item(),确保在CPU上计算 + valid_cpu = valid.detach().cpu() + self.add(f"{name}_mean", valid_cpu.mean().item()) + self.add(f"{name}_std", valid_cpu.std(unbiased=False).item() if valid_cpu.numel() > 1 else 0.0) + self.add(f"{name}_min", valid_cpu.min().item()) + self.add(f"{name}_max", valid_cpu.max().item()) + else: + self.add(f"{name}_mean", 0.0) + self.add(f"{name}_std", 0.0) + self.add(f"{name}_min", 0.0) + self.add(f"{name}_max", 0.0) + + self.tensor_stats.clear() + + def get_metrics(self) -> Dict[str, float]: + """获取所有指标""" + return self.stats.copy() + +class InferCorrectionHandler: + """处理重要性采样校正和样本拒绝的核心类""" + def __init__(self, pipeline_config: "PPOConfig"): + self.pipeline_config = pipeline_config + self.stats = StatsCollector() + + def __call__( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor, + pg_loss: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]: + """ + 主入口:执行重要性采样校正和样本拒绝 + + Args: + old_log_probs: 历史策略的log概率 [B, T] + infer_log_probs: 生成时策略的log概率 [B, T] + response_mask: 有效token掩码 [B, T] + pg_loss: 原始策略梯度损失 [B, T] + + Returns: + weighted_loss: 重加权后的损失 + final_mask: 最终保留的token掩码 + metrics: 诊断指标字典 + """ + # 1. 对齐形状 + infer_log_probs = self._align_shapes(old_log_probs, infer_log_probs, response_mask) + + # 2. 
计算IS权重 + ratio, raw_is_weight, is_weight = self._compute_is_weights(old_log_probs, infer_log_probs, response_mask) + + # 3. 收集基础统计 + self._collect_base_stats(ratio, response_mask) + + # 4. 应用拒绝策略 + keep_mask = response_mask.clone() + keep_mask = self._apply_token_rejection(ratio, keep_mask) + keep_mask = self._apply_catastrophic_rejection(ratio, keep_mask, response_mask) + keep_mask = self._apply_sequence_rejection(ratio, keep_mask, response_mask) + + # 5. 计算拒绝统计 + self._collect_rejection_stats(ratio, raw_is_weight, keep_mask, response_mask) + + # 6. 重加权损失 + weighted_loss = pg_loss * is_weight + + # 7. 计算KL指标 + self._compute_kl_metrics(old_log_probs, infer_log_probs, keep_mask, response_mask) + + # 8. 批量计算张量统计 + self.stats.compute_tensor_stats() + + return weighted_loss, keep_mask, self.stats.get_metrics() + + def _align_shapes( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor + ) -> torch.Tensor: + """对齐log概率张量形状""" + if infer_log_probs.shape[1] == old_log_probs.shape[1] + 1: + infer_log_probs = infer_log_probs[:, 1:] + + assert old_log_probs.shape == infer_log_probs.shape == response_mask.shape, ( + f"Shape mismatch: old_log_probs {old_log_probs.shape}, " + f"infer_log_probs {infer_log_probs.shape}, " + f"response_mask {response_mask.shape}" + ) + return infer_log_probs + + def _compute_is_weights( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + 计算重要性采样权重 + + Returns: + ratio: 原始重要性比率 [B, T] + raw_is_weight: 未裁剪的IS权重 [B, T] + is_weight: 裁剪后的IS权重 [B, T] + """ + log_ratio = old_log_probs - infer_log_probs + ratio = torch.exp(log_ratio) + + if self.pipeline_config.infer_is_mode == "token": + raw_is_weight = ratio + elif self.pipeline_config.infer_is_mode == "sequence": + # 序列级IS:使用序列总log-ratio + log_ratio_sum = self._masked_sum(log_ratio, response_mask, dim=-1).unsqueeze(-1) + seq_ratio = torch.exp(log_ratio_sum) + raw_is_weight = seq_ratio.expand_as(ratio) + # 收集序列级统计 + self.stats.add_tensor_stat("seq_ratio", seq_ratio.squeeze(-1), torch.ones_like(seq_ratio.squeeze(-1), dtype=torch.bool)) + else: # "None" or any other value + raw_is_weight = torch.ones_like(ratio) + + # 裁剪IS权重 + is_weight = raw_is_weight.clamp( + min=self.pipeline_config.infer_is_threshold_min, + max=self.pipeline_config.infer_is_threshold_max + ).detach() + + return ratio, raw_is_weight, is_weight + + def _collect_base_stats(self, ratio: torch.Tensor, response_mask: torch.Tensor): + """收集基础统计指标""" + self.stats.add_tensor_stat("token_ratio", ratio, response_mask) + + if self.pipeline_config.infer_is_mode in ("token", "sequence"): + # 1. 裁剪比例统计(现有代码) + clipped_low = ratio <= self.pipeline_config.infer_is_threshold_min + clipped_high = ratio >= self.pipeline_config.infer_is_threshold_max + clipped = clipped_low | clipped_high + self.stats.add("token_clip_low_frac", self._agg_loss(clipped_low.float(), response_mask)) + self.stats.add("token_clip_high_frac", self._agg_loss(clipped_high.float(), response_mask)) + self.stats.add("token_clip_frac", self._agg_loss(clipped.float(), response_mask)) + + # 2. 
添加缺失的:裁剪后权重的分布统计 + if self.pipeline_config.infer_is_mode == "token": + # 重新计算裁剪后的权重 + is_weight = ratio.clamp( + min=self.pipeline_config.infer_is_threshold_min, + max=self.pipeline_config.infer_is_threshold_max + ) + # 添加缺失的统计 + self.stats.add_tensor_stat("token_is_weight", is_weight, response_mask) + + elif self.pipeline_config.infer_is_mode == "sequence": + # 序列级IS权重已在_compute_is_weights中添加 + pass + + def _apply_token_rejection( + self, + ratio: torch.Tensor, + keep_mask: torch.Tensor + ) -> torch.Tensor: + """应用token级拒绝策略""" + if not self.pipeline_config.enable_token_reject: + return keep_mask + + ratio_too_high = ratio > self.pipeline_config.infer_token_mask_threshold_max + ratio_too_low = ratio < self.pipeline_config.infer_token_mask_threshold_min + token_reject = ratio_too_high | ratio_too_low + + # 更新掩码:丢弃被拒绝的token + new_keep_mask = keep_mask & (~token_reject) + + # 收集统计 + self.stats.add("token_reject_low_frac", self._agg_loss(ratio_too_low.float(), keep_mask)) + self.stats.add("token_reject_high_frac", self._agg_loss(ratio_too_high.float(), keep_mask)) + + return new_keep_mask + + def _apply_catastrophic_rejection( + self, + ratio: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ) -> torch.Tensor: + """应用灾难性拒绝策略""" + if not self.pipeline_config.enable_catastrophic_reject: + return keep_mask + + # 识别灾难性token + catastrophic = (ratio < self.pipeline_config.infer_catastrophic_threshold) & response_mask + + # 检查哪些序列包含灾难性token + seq_has_catastrophic = catastrophic.any(dim=-1, keepdim=True) + + # 更新掩码:丢弃包含灾难性token的整个序列 + new_keep_mask = keep_mask & (~seq_has_catastrophic) + + # 收集统计 + catastrophic_token_frac = self._agg_loss(catastrophic.float(), response_mask) + self.stats.add("catastrophic_token_frac", catastrophic_token_frac) + + # 计算包含灾难性token的序列比例 + seq_has_valid = response_mask.any(dim=-1) + seq_has_catastrophic_flat = catastrophic.any(dim=-1) & seq_has_valid + catastrophic_seq_frac = ( + seq_has_catastrophic_flat.sum().float() / seq_has_valid.sum().float() + if seq_has_valid.sum() > 0 else 0.0 + ) + self.stats.add("catastrophic_seq_frac", catastrophic_seq_frac) + + return new_keep_mask + + def _apply_sequence_rejection( + self, + ratio: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ) -> torch.Tensor: + """应用序列级拒绝策略""" + if self.pipeline_config.enable_seq_reject in (None, "None", "none"): + return keep_mask + + # 计算序列级比率 + if self.pipeline_config.enable_seq_reject == "sequence": + log_ratio_agg = self._masked_sum(torch.log(ratio), response_mask, dim=-1) + elif self.pipeline_config.enable_seq_reject == "geometric": + log_ratio_agg = self._masked_mean(torch.log(ratio), response_mask, dim=-1) + else: + return keep_mask + + seq_ratio = torch.exp(log_ratio_agg) + + # 识别要拒绝的序列 + seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max + seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min + seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) + + # 更新掩码 + new_keep_mask = keep_mask & (~seq_reject) + + # 收集统计 + seq_has_valid = response_mask.any(dim=-1) + total_valid_seqs = seq_has_valid.sum().item() + + seq_reject_low = seq_too_low & seq_has_valid + seq_reject_high = seq_too_high & seq_has_valid + + seq_reject_low_frac = seq_reject_low.sum().item() / total_valid_seqs if total_valid_seqs > 0 else 0.0 + seq_reject_high_frac = seq_reject_high.sum().item() / total_valid_seqs if total_valid_seqs > 0 else 0.0 + + self.stats.add("seq_reject_low_frac", seq_reject_low_frac) + 
self.stats.add("seq_reject_high_frac", seq_reject_high_frac) + + return new_keep_mask + + def _collect_rejection_stats( + self, + ratio: torch.Tensor, + raw_is_weight: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ): + """收集拒绝相关的统计指标""" + + # 计算被拒绝的token + rejected_mask = response_mask & (~keep_mask) + self.stats.add("reject_frac", self._agg_loss(rejected_mask.float(), response_mask)) + + + # 仅在序列拒绝启用时计算序列级拒绝率 + if self.pipeline_config.enable_seq_reject not in (None, "None", "none"): + seq_has_valid = response_mask.any(dim=-1) + seq_completely_rejected = (~keep_mask).all(dim=-1) & seq_has_valid + total_valid_seqs = seq_has_valid.sum().item() + rejected_seqs = seq_completely_rejected.sum().item() + seq_reject_frac = rejected_seqs / total_valid_seqs if total_valid_seqs > 0 else 0.0 + self.stats.add("seq_reject_frac", seq_reject_frac) + else: + # 未启用时显式设为0.0 + self.stats.add("seq_reject_frac", 0.0) + + + if self.pipeline_config.infer_is_mode in ("token", "sequence"): + # 使用已计算的rejected_mask + clipped_mask = ((raw_is_weight <= self.pipeline_config.infer_is_threshold_min) | + (raw_is_weight >= self.pipeline_config.infer_is_threshold_max)) & response_mask + + clip_and_reject_frac = self._agg_loss((clipped_mask & rejected_mask).float(), response_mask) + clip_or_reject_frac = self._agg_loss((clipped_mask | rejected_mask).float(), response_mask) + + self.stats.add("token_clip_and_reject_frac", clip_and_reject_frac) + self.stats.add("token_clip_or_reject_frac", clip_or_reject_frac) + else: + # 关键:为未启用IS的情况提供默认值 + self.stats.add("token_clip_and_reject_frac", 0.0) + self.stats.add("token_clip_or_reject_frac", 0.0) + + def _compute_kl_metrics( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ): + """计算KL散度指标""" + # 原始KL(所有有效token) + inferkl_orig = self._compute_approx_kl(infer_log_probs, old_log_probs, response_mask, kl_penalty="kl") + inferkl_orig_agg = self._agg_loss(inferkl_orig, response_mask) + self.stats.add("inferkl", inferkl_orig_agg) + + # 拒绝后KL(仅保留的token) + inferkl_final = self._compute_approx_kl(infer_log_probs, old_log_probs, keep_mask, kl_penalty="kl") + inferkl_final_agg = self._agg_loss(inferkl_final, keep_mask) + self.stats.add("inferkl_reject", inferkl_final_agg) + + # --- 辅助方法(使用已有工具函数)--- + def _compute_approx_kl( + self, + log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + action_mask: torch.Tensor, + kl_penalty: str = "kl" + ) -> torch.Tensor: + """使用已有的compute_approx_kl函数计算近似KL散度""" + from roll.utils.functionals import compute_approx_kl + return compute_approx_kl( + log_probs=log_probs, + log_probs_base=log_probs_base, + action_mask=action_mask, + kl_penalty=kl_penalty + ) + + def _agg_loss(self, loss_mat: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor: + """使用已有的agg_loss函数聚合损失""" + from roll.utils.functionals import agg_loss + return agg_loss( + loss_mat=loss_mat, + loss_mask=loss_mask, + loss_agg_mode=self.pipeline_config.loss_agg_mode + ) + + def _masked_sum(self, tensor: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: + """使用已有的masked_sum函数在掩码区域求和""" + from roll.utils.functionals import masked_sum + return masked_sum(tensor, mask, dim=dim) + + def _masked_mean(self, tensor: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: + """使用已有的masked_mean函数在掩码区域计算均值""" + from roll.utils.functionals import masked_mean + return masked_mean(tensor, mask, dim=dim) \ No newline at end of file From 
a0ea354017c106662cd5946845525816a5ba58ed Mon Sep 17 00:00:00 2001 From: "chengengru.cgr" Date: Tue, 28 Oct 2025 14:49:10 +0800 Subject: [PATCH 05/58] (fix): update math rule reward worker. --- .../rlvr/rewards/math_rule_reward_worker.py | 50 +++++++++++++------ 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/roll/pipeline/rlvr/rewards/math_rule_reward_worker.py b/roll/pipeline/rlvr/rewards/math_rule_reward_worker.py index a74878840..7a7b8df19 100644 --- a/roll/pipeline/rlvr/rewards/math_rule_reward_worker.py +++ b/roll/pipeline/rlvr/rewards/math_rule_reward_worker.py @@ -40,31 +40,44 @@ def __enter__(self): def __exit__(self, type, value, traceback): signal.alarm(0) -def _extract_after_last_end_think(response: str) -> str: +def _extract_after_last_end_think(response: str, prompt: str, start_think: str='', end_think: str='') -> str: """ 提取字符串中最后一个 "" 标签之后的所有文本。 - 校验逻辑: - - 如果字符串中包含开标签 "",直接返回空字符串。 - - 如果字符串中包含超过一个的闭标签 "",也直接返回空字符串。 + 校验逻辑会根据 prompt 的结尾而变化: + - (1) 如果 prompt 的结尾(去掉换行符后)是以 "" 结尾: + - response 中不允许包含开标签 ""。 + - response 中包含的闭标签 "" 不能超过一个。 + - 若不满足,则返回空字符串。 + - (2) 否则(prompt 不以 "" 结尾): + - response 中包含的闭标签 "" 不能超过一个。 + - 如果 response 中包含开标签 "",它必须出现在字符串的开头。 + - 若不满足,则返回空字符串。 - 如果校验通过,则执行原有逻辑: + 如果校验通过,则执行提取逻辑: 1. 优先按最后一个 '' 分割。 2. 如果找不到,则回退到按最后一个双换行符 '\n\n' 分割。 3. 如果都找不到,则返回空字符串。 Args: response (str): 输入的完整文本。 + prompt (str): 用于生成 response 的提示文本。 Returns: str: 提取出的文本块(已去除首尾空格),或空字符串。 """ - # 如果检测到 "" 或超过一个 "",直接返回空字符串 - if "" in response or response.count('') > 1: - return "" - + # 检查 prompt 是否以 结尾 + is_prompt_ending_with_think = prompt.rstrip('\n').endswith(start_think) + + if is_prompt_ending_with_think: + if start_think in response or response.count(end_think) > 1: + return "" + else: + if response.count(end_think) > 1 or start_think in response and not response.startswith(start_think): + return "" + # 1. 优先尝试按 '' 分割 - _before_think, sep_think, after_think = response.rpartition('') + _before_think, sep_think, after_think = response.rpartition(end_think) if sep_think: # 如果找到了 '',则返回它后面的部分,并清理首尾空格 @@ -79,10 +92,10 @@ def _extract_after_last_end_think(response: str) -> str: # 3. 
如果连 '\n\n' 都没找到,则返回空字符串 return "" -def _hf_verify_math_sample(response, answer, result): +def _hf_verify_math_sample(response, answer, result, prompt): try: # 在解析之前,先对模型的原始输出进行预处理 - cleaned_response = _extract_after_last_end_think(response) + cleaned_response = _extract_after_last_end_think(response, prompt) """ --- `parse` 函数完整参数介绍与使用建议 --- `parse` 函数用于从文本中提取并解析数学答案,其主要参数如下: @@ -133,13 +146,13 @@ def _hf_verify_math_sample(response, answer, result): result.append((False, "", "")) -def hf_verify_math_sample(answer_a, answer_b, timeout_sec=5.0): +def hf_verify_math_sample(answer_a, answer_b, prompt, timeout_sec=5.0): with multiprocessing.Manager() as manager: result = manager.list() p = multiprocessing.Process( target=_hf_verify_math_sample, - args=(answer_a, answer_b, result) + args=(answer_a, answer_b, result, prompt) ) p.start() @@ -219,13 +232,18 @@ def compute_rewards(self, data: DataProto): format_rewards = [] response_text_list = self.tokenizer.batch_decode(data.batch["responses"], skip_special_tokens=False) - for response, answer in zip(response_text_list, data.non_tensor_batch["ground_truth"]): + prompt_text_list = self.tokenizer.batch_decode(data.batch["prompts"], skip_special_tokens=False) + for response, answer, prompt in zip(response_text_list, data.non_tensor_batch["ground_truth"], prompt_text_list): + + prompt = prompt.replace("<|endoftext|>", "").replace("", "") response = response.replace("<|endoftext|>", "").replace("", "") + # self.logger.info(json.dumps({ + # "prompt": prompt}, ensure_ascii=False)) try: with timeout(5): correct, extracted_ground_truth, extracted_response = hf_verify_math_sample( - response, f"${answer}$" + response, f"${answer}$", prompt ) log_data = { From 055ef9b791e1b1d86459d1984b4bf7b1150dffdf Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Wed, 29 Oct 2025 14:00:40 +0800 Subject: [PATCH 06/58] (feat): set RAY_CGRAPH_get_timeout=600. --- roll/platforms/cpu.py | 1 + roll/platforms/cuda.py | 1 + roll/platforms/npu.py | 1 + 3 files changed, 3 insertions(+) diff --git a/roll/platforms/cpu.py b/roll/platforms/cpu.py index ec4fc6892..3149938d3 100644 --- a/roll/platforms/cpu.py +++ b/roll/platforms/cpu.py @@ -30,6 +30,7 @@ def get_custom_env_vars(cls) -> dict: # too long. "RAY_get_check_signal_interval_milliseconds": "1", "VLLM_ALLOW_INSECURE_SERIALIZATION": "1", + "RAY_CGRAPH_get_timeout": '600', } return env_vars diff --git a/roll/platforms/cuda.py b/roll/platforms/cuda.py index 2d9911e03..fcf2c289a 100644 --- a/roll/platforms/cuda.py +++ b/roll/platforms/cuda.py @@ -32,6 +32,7 @@ def get_custom_env_vars(cls) -> dict: env_vars = { # "RAY_DEBUG": "legacy" "RAY_get_check_signal_interval_milliseconds": "1", + "RAY_CGRAPH_get_timeout": '600', "VLLM_ALLOW_INSECURE_SERIALIZATION": "1", "TORCHINDUCTOR_COMPILE_THREADS": "2", "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", diff --git a/roll/platforms/npu.py b/roll/platforms/npu.py index 509d88755..c3d0b6a94 100644 --- a/roll/platforms/npu.py +++ b/roll/platforms/npu.py @@ -39,6 +39,7 @@ def get_custom_env_vars(cls) -> dict: # too long. 
"RAY_get_check_signal_interval_milliseconds": "1", "VLLM_ALLOW_INSECURE_SERIALIZATION": "1", + "RAY_CGRAPH_get_timeout": '600', } return env_vars From 29a7610e99874b50ced48ca7e74d1c276e9fde60 Mon Sep 17 00:00:00 2001 From: "tianhe.lzd" Date: Wed, 29 Oct 2025 19:21:38 +0800 Subject: [PATCH 07/58] (fix): vllm 0.11.0 import --- roll/third_party/vllm/vllm_0_11_0/ray_distributed_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roll/third_party/vllm/vllm_0_11_0/ray_distributed_executor.py b/roll/third_party/vllm/vllm_0_11_0/ray_distributed_executor.py index 483eae457..1b9288825 100644 --- a/roll/third_party/vllm/vllm_0_11_0/ray_distributed_executor.py +++ b/roll/third_party/vllm/vllm_0_11_0/ray_distributed_executor.py @@ -11,7 +11,7 @@ from vllm.executor.msgspec_utils import encode_hook from vllm.executor.ray_distributed_executor import RayDistributedExecutor, RayWorkerMetaData from vllm.executor.ray_utils import RayWorkerWrapper -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.v1.outputs import SamplerOutput from vllm.platforms import current_platform from vllm.ray.ray_env import get_env_vars_to_copy from vllm.utils import make_async, get_ip, get_distributed_init_method, get_open_port From bcc5818c854a8d741e9d0d15d5b29a0ff7d52e76 Mon Sep 17 00:00:00 2001 From: "scott.lxy" Date: Wed, 5 Nov 2025 11:26:52 +0800 Subject: [PATCH 08/58] (fix): fix train infer ratio/diff mean & add train infer ratio/diff token/seq mask & add rollout importance sampling. --- roll/pipeline/rlvr/actor_worker.py | 46 +++++++++++++++++++++++++++--- roll/pipeline/rlvr/rlvr_config.py | 17 +++++++++++ 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/roll/pipeline/rlvr/actor_worker.py b/roll/pipeline/rlvr/actor_worker.py index fbd7147bd..47d8a6b51 100644 --- a/roll/pipeline/rlvr/actor_worker.py +++ b/roll/pipeline/rlvr/actor_worker.py @@ -55,14 +55,40 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="kl" ) + train_infer_ratio = (old_log_probs - infer_log_probs).exp() + train_infer_diff = old_log_probs.exp() - infer_log_probs.exp() + train_infer_ratio_seq = masked_mean(old_log_probs - infer_log_probs, response_mask, dim=-1).exp().unsqueeze(-1).expand_as(train_infer_ratio) + train_infer_diff_seq = masked_mean(old_log_probs.exp() - infer_log_probs.exp(), response_mask, dim=-1).unsqueeze(-1).expand_as(train_infer_diff) + + train_infer_ratio_mask_mean = 1.0 + train_infer_diff_mask_mean = 1.0 + train_infer_ratio_seq_mask_mean = 1.0 + train_infer_diff_seq_mask_mean = 1.0 + + if self.pipeline_config.train_infer_ratio_mask: + train_infer_ratio_mask = (train_infer_ratio <= self.pipeline_config.train_infer_ratio_threshold_high).float() * (train_infer_ratio >= self.pipeline_config.train_infer_ratio_threshold_low).float() + train_infer_ratio_mask_mean = masked_mean(train_infer_ratio_mask, final_response_mask, dim=-1).mean().detach().item() + final_response_mask = final_response_mask * train_infer_ratio_mask + if self.pipeline_config.train_infer_diff_mask: + train_infer_diff_mask = (train_infer_diff <= self.pipeline_config.train_infer_diff_threshold_high).float() * (train_infer_diff >= self.pipeline_config.train_infer_diff_threshold_low).float() + train_infer_diff_mask_mean = masked_mean(train_infer_diff_mask, final_response_mask, dim=-1).mean().detach().item() + final_response_mask = final_response_mask * train_infer_diff_mask + + if 
self.pipeline_config.train_infer_ratio_seq_mask: + train_infer_ratio_seq_mask = (train_infer_ratio_seq <= self.pipeline_config.train_infer_ratio_seq_threshold_high).float() * (train_infer_ratio_seq >= self.pipeline_config.train_infer_ratio_seq_threshold_low).float() + train_infer_ratio_seq_mask_mean = masked_mean(train_infer_ratio_seq_mask, final_response_mask, dim=-1).mean().detach().item() + final_response_mask = final_response_mask * train_infer_ratio_seq_mask + if self.pipeline_config.train_infer_diff_seq_mask: + train_infer_diff_seq_mask = (train_infer_diff_seq <= self.pipeline_config.train_infer_diff_seq_threshold_high).float() * (train_infer_diff_seq >= self.pipeline_config.train_infer_diff_seq_threshold_low).float() + train_infer_diff_seq_mask_mean = masked_mean(train_infer_diff_seq_mask, final_response_mask, dim=-1).mean().detach().item() + final_response_mask = final_response_mask * train_infer_diff_seq_mask + if self.pipeline_config.importance_sampling == "token": ratio = (log_probs - old_log_probs).exp() - train_infer_ratio = (log_probs - infer_log_probs).exp() - train_infer_diff = log_probs.exp() - infer_log_probs.exp() elif self.pipeline_config.importance_sampling == "seq": log_ratio = log_probs - old_log_probs masked_log_ratio = masked_mean(log_ratio, final_response_mask, dim=-1) - ratio = masked_log_ratio.exp().unsqueeze(-1).expand_as(log_ratio) + ratio = masked_log_ratio.exp().unsqueeze(-1).expand_as(log_ratio) pg_clip_low = self.pipeline_config.pg_clip_low if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip pg_clip_high = self.pipeline_config.pg_clip_high if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip @@ -70,10 +96,15 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): surr2 = ratio.clamp(1 - pg_clip_low, 1 + pg_clip_high) * advantages loss = -torch.min(surr1, surr2) + if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) loss = torch.where(advantages < 0, dual_clip_loss, loss) + if self.pipeline_config.use_rollout_importance_sampling_ratio: + rollout_importance_sampling_clip = (train_infer_ratio > self.pipeline_config.rollout_importance_sampling_ratio_upper_bound).float() + loss = train_infer_ratio.clamp(0, self.pipeline_config.rollout_importance_sampling_ratio_upper_bound) * loss + weighted_pg_loss = agg_loss(loss_mat=loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights, loss_scale=loss_scale) @@ -124,8 +155,12 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): train_infer_prob_metric = { "actor/train_infer_ratio_mean": masked_mean(train_infer_ratio, response_mask, dim=-1).mean().detach().item(), "actor/train_infer_diff_mean": masked_mean(train_infer_diff, response_mask, dim=-1).mean().detach().item(), + "actor/train_infer_ratio_mask_mean": train_infer_ratio_mask_mean, + "actor/train_infer_diff_mask_mean": train_infer_diff_mask_mean, + "actor/train_infer_ratio_seq_mask_mean": train_infer_ratio_seq_mask_mean, + "actor/train_infer_diff_seq_mask_mean": train_infer_diff_seq_mask_mean, } - + loss_metric = { "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(), "actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(), @@ -137,6 +172,9 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): loss_agg_mode=self.pipeline_config.loss_agg_mode, loss_scale=loss_scale).detach().item(), } + if 
self.pipeline_config.use_rollout_importance_sampling_ratio: + loss_metric["actor/rollout_importance_sampling_clip"] = rollout_importance_sampling_clip.mean().detach().item() + pg_metrics = { "actor/pg_loss": original_pg_loss.detach().item(), "actor/weighted_pg_loss": weighted_pg_loss.detach().item(), diff --git a/roll/pipeline/rlvr/rlvr_config.py b/roll/pipeline/rlvr/rlvr_config.py index ef0d07c4d..cb837102e 100644 --- a/roll/pipeline/rlvr/rlvr_config.py +++ b/roll/pipeline/rlvr/rlvr_config.py @@ -149,6 +149,23 @@ class RLVRConfig(PPOConfig): importance_sampling: Literal["token", "seq"] = ( field(default="token", metadata={"help": "policy importance sampling"}) ) + use_rollout_importance_sampling_ratio: bool = field(default=False, metadata={"help": "apply train/infer ratio as token-level loss weight"}) + rollout_importance_sampling_ratio_upper_bound: float = field(default=1.2) + + train_infer_ratio_mask: bool = field(default=False, metadata={"help": "apply train/infer ratio as token-level response mask"}) + train_infer_ratio_threshold_low: float = field(default=0.8) + train_infer_ratio_threshold_high: float = field(default=1.2) + train_infer_diff_mask: bool = field(default=False, metadata={"help": "apply train-infer diff as token-level response mask"}) + train_infer_diff_threshold_low: float = field(default=-0.2) + train_infer_diff_threshold_high: float = field(default=0.2) + + train_infer_ratio_seq_mask: bool = field(default=False, metadata={"help": "apply train/infer ratio as sequence-level response mask"}) + train_infer_ratio_seq_threshold_low: float = field(default=0.8) + train_infer_ratio_seq_threshold_high: float = field(default=1.2) + train_infer_diff_seq_mask: bool = field(default=False, metadata={"help": "apply train-infer diff as sequence-level response mask"}) + train_infer_diff_seq_threshold_low: float = field(default=-0.2) + train_infer_diff_seq_threshold_high: float = field(default=0.2) + val_greedy: bool = field(default=False, metadata={"help": "Use greedy for validation"}) val_n_sample: int = field(default=1, metadata={"help": "Number of samples for validation"}) max_len_mask: bool = field(default=False) From a252c2ecaf11eef1fa256b6b9e9718e9a4f49782 Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Wed, 5 Nov 2025 14:05:22 +0800 Subject: [PATCH 09/58] (feat): support vllm beam_search. --- .../docs/User Guides/Configuration/vllm.md | 14 + .../current/User Guides/Configuration/vllm.md | 15 ++ roll/configs/generating_args.py | 4 +- .../scheduler/generate_scheduler.py | 73 +++++- .../distributed/scheduler/reward_scheduler.py | 1 - roll/distributed/strategy/vllm_strategy.py | 118 +++++---- roll/pipeline/rlvr/rlvr_pipeline.py | 21 +- .../scheduler/test_protocol_padding.py | 163 ++++++++++++ .../test_vllm_strategy_beam_search.py | 242 ++++++++++++++++++ 9 files changed, 584 insertions(+), 67 deletions(-) create mode 100644 tests/distributed/scheduler/test_protocol_padding.py create mode 100644 tests/distributed/strategy/test_vllm_strategy_beam_search.py diff --git a/docs_roll/docs/User Guides/Configuration/vllm.md b/docs_roll/docs/User Guides/Configuration/vllm.md index a70773820..7d824ccff 100644 --- a/docs_roll/docs/User Guides/Configuration/vllm.md +++ b/docs_roll/docs/User Guides/Configuration/vllm.md @@ -74,6 +74,20 @@ In the configuration example, we can see: This design allows different components to choose the most suitable inference engine according to their needs. 
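The train/infer ratio and diff masks added in the patch above drop tokens whose probability ratio between the training policy (`old_log_probs`) and the rollout engine (`infer_log_probs`) leaves a configured band. A minimal, tensor-only sketch of the token-level ratio mask follows; the helper name and toy values are illustrative, and only the default thresholds mirror `train_infer_ratio_threshold_low/high` from `rlvr_config.py`.

```python
import torch

def train_infer_ratio_token_mask(old_log_probs: torch.Tensor,
                                 infer_log_probs: torch.Tensor,
                                 response_mask: torch.Tensor,
                                 threshold_low: float = 0.8,
                                 threshold_high: float = 1.2) -> torch.Tensor:
    """Zero out tokens whose train/infer probability ratio leaves [low, high]."""
    # Per-token ratio exp(log pi_train - log pi_infer), as in the patch above.
    ratio = (old_log_probs - infer_log_probs).exp()
    keep = (ratio >= threshold_low) & (ratio <= threshold_high)
    return response_mask * keep.float()

# Toy tensors (illustrative): the third token's ratio (~0.17) falls below 0.8.
old_lp = torch.log(torch.tensor([[0.50, 0.20, 0.05]]))
inf_lp = torch.log(torch.tensor([[0.48, 0.22, 0.30]]))
print(train_infer_ratio_token_mask(old_lp, inf_lp, torch.ones(1, 3)))  # tensor([[1., 1., 0.]])
```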
+### beam_search Configuration +RLVRPipeline supports vllm beam_search generation method, configured as follows: +```yaml +generate_opt_level: 0 # Degrades to batch_generate generation method, generate_opt_level=1 is prompt-level parallel method +num_return_sequences_in_group: 8 +actor_infer: + generating_args: + num_beams: ${num_return_sequences_in_group} + num_return_sequences: ${num_return_sequences_in_group} +``` +Note: +- generating_args.num_beams and generating_args.num_return_sequences must be set to the same value. +- The generating_args configuration in validate is also configured in the same way. + ## Performance Optimization Recommendations 1. **Memory Management**: diff --git a/docs_roll/i18n/zh-Hans/docusaurus-plugin-content-docs/current/User Guides/Configuration/vllm.md b/docs_roll/i18n/zh-Hans/docusaurus-plugin-content-docs/current/User Guides/Configuration/vllm.md index 84543d004..f2cc4574e 100644 --- a/docs_roll/i18n/zh-Hans/docusaurus-plugin-content-docs/current/User Guides/Configuration/vllm.md +++ b/docs_roll/i18n/zh-Hans/docusaurus-plugin-content-docs/current/User Guides/Configuration/vllm.md @@ -74,6 +74,21 @@ actor_infer: 这种设计允许不同组件根据其需求选择最适合的推理引擎。 +### beam_search 配置方式 +RLVRPipeline 支持vllm beam_search 的生成方式,配置方式如下: +```yaml +generate_opt_level: 0 # 退化为batch_generate生成方式,generate_opt_level=1是prompt粒度并行方式 +num_return_sequences_in_group: 8 +actor_infer: + generating_args: + num_beams: ${num_return_sequences_in_group} + num_return_sequences: ${num_return_sequences_in_group} +``` +注意: +- generating_args.num_beams 和 generating_args.num_return_sequences 必须设置为相同的值。 +- validate中配置generating_args也是相同的方式。 + + ## 性能优化建议 1. **内存管理**: diff --git a/roll/configs/generating_args.py b/roll/configs/generating_args.py index 059aff4a9..68cf88d17 100644 --- a/roll/configs/generating_args.py +++ b/roll/configs/generating_args.py @@ -71,4 +71,6 @@ def to_dict(self) -> Dict[str, Any]: def __post_init__(self): if self.stop_strings is not None: - self.stop_strings = list(self.stop_strings) \ No newline at end of file + self.stop_strings = list(self.stop_strings) + if self.num_beams > 1: + self.num_return_sequences = self.num_beams diff --git a/roll/distributed/scheduler/generate_scheduler.py b/roll/distributed/scheduler/generate_scheduler.py index c49b502c1..8b6c5c06b 100644 --- a/roll/distributed/scheduler/generate_scheduler.py +++ b/roll/distributed/scheduler/generate_scheduler.py @@ -24,8 +24,9 @@ import os from roll.distributed.executor.cluster import Cluster -from roll.distributed.scheduler.protocol import DataProto, collate_fn -from roll.models.model_providers import default_tokenizer_provider +from roll.distributed.scheduler.protocol import DataProto, collate_fn, pad_dataproto_to_divisor, unpad_dataproto +from roll.distributed.scheduler.reward_scheduler import RewardScheduler +from roll.models.model_providers import default_tokenizer_provider, default_processor_provider from roll.utils.constants import RAY_NAMESPACE from roll.utils.functionals import ( postprocess_generate, @@ -365,6 +366,7 @@ def __init__(self, pipeline_config=None): self.exception_queue = queue.Queue() self.running = False self.dataset_epoch = 0 + self.reward_scheduler = RewardScheduler() # Flow control measures. max_running_requests limits the maximum number of concurrent requests for each dp. # max_additional_running_prompts limits the number of prompts running simultaneously to avoid excessive consumption of prompts. 
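The beam-search docs above require `num_beams == num_return_sequences`, and `GeneratingArguments.__post_init__` now enforces that coupling. A small sketch of the detection-plus-coupling rule, mirroring the helper this patch adds to `vllm_strategy.py` further below (the dict literal is illustrative):

```python
def should_use_beam_search(generation_config: dict) -> bool:
    # Beam search is selected when num_beams > 1 or an explicit flag is set,
    # mirroring VllmStrategy._should_use_beam_search in this patch.
    return (generation_config.get("num_beams", 1) > 1
            or generation_config.get("use_beam_search", False))


cfg = {"num_beams": 8, "num_return_sequences": 8, "max_new_tokens": 128}
assert should_use_beam_search(cfg)
# Each prompt yields beam_width sequences, hence the documented requirement:
assert cfg["num_beams"] == cfg["num_return_sequences"]
```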
@@ -484,6 +486,54 @@ def reset_status(self): mininterval=int(self.batch_size * 0.1) + 1, ) + def get_batch_opt_level_0(self, data: DataProto, batch_size: int) -> DataProto: + completed_data: List[DataProto] = [] + query_use_count = 0 + + while len(completed_data) < batch_size: + data_item_list = [self.get_next_dataset_item() for _ in range(batch_size)] + collect_data = self.collect_fn(data_item_list) + request_data: DataProto = DataProto.from_single_dict(collect_data, meta_info=data.meta_info) + request_data.batch["prompt_id"] = torch.arange(request_data.batch.batch_size[0], device=request_data.batch.device) + + gen_batch = request_data.pop(batch_keys=["input_ids", "attention_mask", "position_ids"]) + gen_batch.meta_info = request_data.meta_info + num_return_sequences = self.generation_config["num_return_sequences"] + request_data = request_data.repeat(repeat_times=num_return_sequences) + + # Pad gen_batch to be divisible by dp_size to avoid errors + gen_batch_padded, pad_size = pad_dataproto_to_divisor(gen_batch, self.actor_cluster.dp_size) + batch: DataProto = self.actor_cluster.generate(gen_batch_padded) + batch = unpad_dataproto(batch, pad_size * num_return_sequences) + + batch.union(other=request_data) + batch.rename(old_keys="prompt_id", new_keys="origin_prompt_id") + batch_rewards = self.reward_scheduler.compute_rewards(data=batch, reward_clusters=self.reward_clusters, pipeline_config=self.pipeline_config) + metrics = batch.meta_info.pop("metrics", {}) + metrics.update(batch_rewards.meta_info.pop("metrics", {})) + + batch.union(other=batch_rewards) + + batch.meta_info["metrics"] = metrics + batch_grouped: Dict[str, DataProto] = batch.group_by("origin_prompt_id") + for prompt_id, batch_item in batch_grouped.items(): + if self.query_filter_fn([batch_item], self.pipeline_config): + completed_data.append(batch_item) + else: + self.query_filter_count += 1 + query_use_count += batch_size + + batch = DataProto.concat(completed_data[: self.batch_size]) + batch.meta_info["metrics"] = { + f"scheduler/query_filter_count": self.query_filter_count, + f"scheduler/response_filter_count": self.response_filter_count, + f"scheduler/collect_query_count": self.batch_size, + f"scheduler/query_use_count": query_use_count, + } + self.reset_status() + return batch + + def get_batch(self, data: DataProto, batch_size: int) -> DataProto: """ 从dataset里,按给定策略sample batch @@ -493,8 +543,12 @@ def get_batch(self, data: DataProto, batch_size: int) -> DataProto: self.batch_size = batch_size self.reset_status() self.running = True - prompt_id_counter = itertools.count() self.generation_config = copy.deepcopy(data.meta_info["generation_config"]) + + if self.pipeline_config.generate_opt_level == 0: + return self.get_batch_opt_level_0(data, batch_size) + + prompt_id_counter = itertools.count() num_return_sequences = self.generation_config["num_return_sequences"] while True: if ( @@ -538,11 +592,7 @@ def get_batch(self, data: DataProto, batch_size: int) -> DataProto: if int(os.environ.get("REPORT_LENGTH_AND_REWARDS", "0")): self.prompt_id_2_hash_str[prompt_id] = base64.urlsafe_b64encode(prompt_digest).decode().rstrip('=') # prompt_id 对应 unique prompt self.requests_buffers[req.meta_info["request_id"]] = req - ray.get( - self.actor_cluster.workers[dp_rank].add_request.remote( - command=GenerateRequestType.ADD, data=req - ) - ) + self.actor_cluster.workers[dp_rank].add_request.remote(command=GenerateRequestType.ADD, data=req) req.meta_info.pop("response_callback_fn") self.load_balance_coordinator[dp_rank] += 1 
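The new `get_batch_opt_level_0` path pads the prompt batch so it splits evenly across data-parallel ranks, then strips `pad_size * num_return_sequences` rows from the generated output. A toy, tensor-only sketch of that bookkeeping (the `pad_to_divisor` helper is a hypothetical stand-in for `pad_dataproto_to_divisor` / `unpad_dataproto`):

```python
import torch

def pad_to_divisor(x: torch.Tensor, divisor: int):
    """Repeat leading rows so the batch dimension is divisible by `divisor`."""
    remainder = x.shape[0] % divisor
    pad_size = 0 if remainder == 0 else divisor - remainder
    if pad_size:
        x = torch.cat([x, x[:pad_size]], dim=0)
    return x, pad_size

prompts = torch.arange(7).unsqueeze(-1)              # 7 prompts, dp_size = 4
padded, pad_size = pad_to_divisor(prompts, 4)        # 8 prompts, pad_size = 1
num_return_sequences = 2
generated = padded.repeat_interleave(num_return_sequences, dim=0)   # 16 rows
# The scheduler drops pad_size * num_return_sequences trailing rows afterwards.
outputs = generated[: generated.shape[0] - pad_size * num_return_sequences]
assert outputs.shape[0] == 7 * num_return_sequences
```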
self.dp_fetch_count[dp_rank] += 1 @@ -605,6 +655,10 @@ def report_response(self, data: DataProto): # call reward # reward worker得能支持单条数据计算, dynamic sampling对需要batch计算reward的需要注意... # 多域的时候,llm as judge, 需要单独为reward worker分配gpu + + # set rollout id + batch.non_tensor_batch["rollout_id"] = np.array([str(uuid.uuid4()) for _ in range(output_count)], dtype=object) + rewards: DataProto = ray.get(reward_worker.compute_rewards.remote(batch)) batch.union(rewards) @@ -844,6 +898,9 @@ async def generate_one_request(self, data: DataProto): pad_token_id = response_data.meta_info["pad_token_id"] output_token_ids = response_data.meta_info["output_token_ids"] output_tokens = [torch.tensor(token_ids) for token_ids in output_token_ids] + + output_logprobs = response_data.meta_info.get("output_logprobs", None) + output_tensor = pad_sequence(output_tokens, batch_first=True, padding_value=pad_token_id) output_tensor = concatenate_input_and_output( input_ids=data.batch["input_ids"], output_ids=output_tensor, num_return_sequences=len(output_tokens) diff --git a/roll/distributed/scheduler/reward_scheduler.py b/roll/distributed/scheduler/reward_scheduler.py index 120bb78d6..5619539e4 100644 --- a/roll/distributed/scheduler/reward_scheduler.py +++ b/roll/distributed/scheduler/reward_scheduler.py @@ -12,7 +12,6 @@ logger = get_logger() -@ray.remote class RewardScheduler: """ reward 服务化和generate不同, request接口: diff --git a/roll/distributed/strategy/vllm_strategy.py b/roll/distributed/strategy/vllm_strategy.py index 6b7be2854..f44dd29f8 100644 --- a/roll/distributed/strategy/vllm_strategy.py +++ b/roll/distributed/strategy/vllm_strategy.py @@ -13,8 +13,9 @@ from torch.nn.utils.rnn import pad_sequence from transformers import set_seed from vllm import RequestOutput, SamplingParams +from vllm.beam_search import BeamSearchOutput from vllm.lora.request import LoRARequest -from vllm.sampling_params import RequestOutputKind +from vllm.sampling_params import RequestOutputKind, BeamSearchParams from vllm.utils import random_uuid from roll.distributed.executor.worker import Worker @@ -124,6 +125,18 @@ def op_compute_log_probs(self, logits: torch.Tensor, input_ids: torch.Tensor, at pass def generate(self, batch: DataProto, generation_config) -> torch.Tensor: + # Check if beam search is requested + if self._should_use_beam_search(generation_config): + return self._generate_with_beam_search(batch, generation_config) + else: + return self._generate_standard(batch, generation_config) + + def _should_use_beam_search(self, generation_config) -> bool: + """Check if beam search should be used based on generation_config.""" + return generation_config.get("num_beams", 1) > 1 or generation_config.get("use_beam_search", False) + + def _generate_standard(self, batch: DataProto, generation_config) -> torch.Tensor: + """Standard generate method for non-beam search cases.""" sampling_params = create_sampling_params_for_vllm(gen_kwargs=generation_config) input_ids = batch.batch["input_ids"] # (bs, prompt_length) @@ -170,6 +183,63 @@ def generate(self, batch: DataProto, generation_config) -> torch.Tensor: return output + def _generate_with_beam_search(self, batch: DataProto, generation_config) -> torch.Tensor: + """Generate using beam search method.""" + # Create beam search parameters + beam_params = BeamSearchParams( + beam_width=generation_config.get("num_beams", 1), + max_tokens=generation_config.get("max_new_tokens", 50), + temperature=generation_config.get("temperature", 0.0), + ignore_eos=generation_config.get("ignore_eos", False), + 
length_penalty=generation_config.get("length_penalty", 1.0), + include_stop_str_in_output=generation_config.get("include_stop_str_in_output", False), + ) + + input_ids = batch.batch["input_ids"] # (bs, prompt_length) + attention_mask = batch.batch["attention_mask"] # left-padded attention_mask + + # Prepare prompts for beam_search + if "multi_modal_data" in batch.non_tensor_batch: + # For multimodal data, we need to handle it differently + # This is a simplified approach - may need refinement based on actual multimodal format + prompts = batch.non_tensor_batch["multi_modal_data"] + else: + # Convert to token lists format expected by beam_search + token_lists = gather_unpadded_input_ids( + input_ids=input_ids, attention_mask=attention_mask + ) + # Convert to TokensPrompt format expected by vLLM beam_search + prompts = [{"prompt_token_ids": token_ids} for token_ids in token_lists] + + # Call beam_search method + beam_search_outputs = self.model.beam_search( + prompts=prompts, + params=beam_params, + ) + + generated_token_ids = [] + token_ids = [prompt['prompt_token_ids'] for prompt in prompts] + for batch_idx, output in enumerate(beam_search_outputs): + # Each output contains beam_width sequences + for beam_idx, sequence in enumerate(output.sequences): + # Get prompt length for this input + prompt_length = len(token_ids[batch_idx]) + # Extract only the generated tokens (exclude prompt) + generated_tokens = sequence.tokens[prompt_length:] + generated_token_ids.append(torch.tensor(generated_tokens, device=input_ids.device)) + + # Pad the sequences + output_ids = pad_sequence(generated_token_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) + + # Concatenate input and output + output = concatenate_input_and_output( + input_ids=input_ids, + output_ids=output_ids, + num_return_sequences=beam_params.beam_width + ) + + return output + def process_vllm_output(self, vllm_outputs: List[RequestOutput], request_complete_callback, collect_unfinished=False): # 转成response id, request_complete_callback report_request_ids = [] @@ -383,20 +453,6 @@ def create_sampling_params_for_vllm(gen_kwargs): assert gen_kwargs["num_return_sequences"] == 1, ( "fetch_output only supports num_return_sequences=1 or output_kind=FINAL" ) - - if gen_kwargs["num_beams"] > 1: - return SamplingParams( - max_tokens=gen_kwargs["max_new_tokens"], - stop_token_ids=gen_kwargs["eos_token_id"], - repetition_penalty=gen_kwargs["repetition_penalty"], - n=gen_kwargs["num_return_sequences"], - best_of=gen_kwargs["num_beams"], - use_beam_search=True, - stop=gen_kwargs["stop_strings"], - logprobs=gen_kwargs.get("logprobs", 0), - output_kind=output_kind, - include_stop_str_in_output=gen_kwargs.get("include_stop_str_in_output", True), - ) return SamplingParams( max_tokens=gen_kwargs["max_new_tokens"], temperature=gen_kwargs["temperature"], @@ -410,35 +466,3 @@ def create_sampling_params_for_vllm(gen_kwargs): output_kind=output_kind, include_stop_str_in_output=gen_kwargs.get("include_stop_str_in_output", True), ) - - -def compare_sampling_params(params1: SamplingParams, params2: SamplingParams) -> bool: - # 只比较采样参数的配置 - param_attrs = [ - "temperature", - "top_p", - "top_k", - "max_tokens", - "n", - "stop_token_ids", - "presence_penalty", - "frequency_penalty", - "repetition_penalty", - "min_p", - "best_of", - "stop", - "ignore_eos", - "use_beam_search", - "best_of", - "use_beam_search", - ] - - # 比较每个采样参数 - for attr in param_attrs: - if hasattr(params1, attr) and hasattr(params2, attr): - val1 = getattr(params1, attr) - val2 = 
getattr(params2, attr) - if val1 != val2: - print(f"采样参数 {attr} 不同: {val1} != {val2}") - return False - return True diff --git a/roll/pipeline/rlvr/rlvr_pipeline.py b/roll/pipeline/rlvr/rlvr_pipeline.py index 4f702f0b6..d4f35b2c2 100644 --- a/roll/pipeline/rlvr/rlvr_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_pipeline.py @@ -433,7 +433,7 @@ def run(self): actor_train_timer = _Timer(window_size=5) pre_step_total_time = 0 - if self.pipeline_config.async_pipeline: + if self.pipeline_config.async_pipeline and self.pipeline_config.generate_opt_level == 1: for reward_cluster in self.rewards.values(): reward_cluster.load_states() @@ -460,7 +460,7 @@ def run(self): self.actor_train.offload_states(blocking=True) with Timer(name="step_stop_server", logger=None) as step_stop_server_timer: - if self.pipeline_config.async_pipeline and not first_step: + if self.pipeline_config.async_pipeline and not first_step and self.pipeline_config.generate_opt_level == 1: scheduler_refs = [] for scheduler in self.generate_schedulers.values(): scheduler_refs.append(scheduler.pause_sampling.remote(data=batch)) @@ -487,7 +487,9 @@ def run(self): Timer(name="step_generate", logger=None) as step_generate_timer, ): domain_batches = {} - self.actor_infer.start_server(data=DataProto(meta_info=batch.meta_info)) + if self.pipeline_config.generate_opt_level == 1: + self.actor_infer.start_server(data=DataProto(meta_info=batch.meta_info)) + batch.meta_info["is_offload_states"] = False if self.pipeline_config.async_pipeline: if should_eval: # 为Validation创建独立的DataProto @@ -508,7 +510,6 @@ def run(self): for reward_cluster in self.rewards.values(): reward_cluster.load_states() - batch.meta_info["is_offload_states"] = False scheduler_refs = {} for domain, scheduler in self.generate_schedulers.items(): scheduler_refs[domain] = scheduler.get_batch.remote( @@ -524,7 +525,7 @@ def run(self): dump_rollout_to_specific_path(self.pipeline_config.rollout_dump_dir, global_step, generate_output, self.tokenizer) generate_output.meta_info.pop("is_offload_states", None) - if not self.pipeline_config.async_pipeline: + if not self.pipeline_config.async_pipeline and self.pipeline_config.generate_opt_level == 1: for reward_cluster in self.rewards.values(): reward_cluster.offload_states() gen_metrics = self.actor_infer.stop_server() @@ -536,7 +537,7 @@ def run(self): batch = generate_output batch.meta_info["global_step"] = global_step - + with Timer(name="cal_ref_log_probs", logger=None) as cal_ref_log_probs_timer: if self.is_lora: @@ -546,7 +547,7 @@ def run(self): else: if self.pipeline_config.reference.use_dynamic_batching_in_infer: batch, dynamic_batching_metrics = dynamic_batching_shard( - batch, + batch, self.reference.dp_size, self.pipeline_config.reference.max_tokens_per_microbatch_in_infer, self.pipeline_config.reference.sequence_length_round_in_infer, @@ -567,7 +568,7 @@ def run(self): values_refs: List[ray.ObjectRef] = self.critic.compute_values(batch, blocking=False) if self.pipeline_config.actor_train.use_dynamic_batching_in_infer: batch, dynamic_batching_metrics = dynamic_batching_shard( - batch, + batch, self.actor_train.dp_size, self.pipeline_config.actor_train.max_tokens_per_microbatch_in_infer, self.pipeline_config.actor_train.sequence_length_round_in_infer, @@ -689,7 +690,7 @@ def run(self): # update actor if self.pipeline_config.actor_train.use_dynamic_batching_in_train: batch, dynamic_batching_metrics = dynamic_batching_shard( - batch, + batch, self.actor_train.dp_size, 
self.pipeline_config.actor_train.max_tokens_per_microbatch_in_train, self.pipeline_config.actor_train.sequence_length_round_in_train, @@ -777,7 +778,7 @@ def val(self): self.val_generate_scheduler.get_batch.remote(data=batch, batch_size=len(self.val_dataset)), timeout=self.pipeline_config.rpc_timeout, ) - if not self.pipeline_config.async_pipeline: + if not self.pipeline_config.async_pipeline and self.pipeline_config.generate_opt_level == 1: self.actor_infer.stop_server() for reward_cluster in self.rewards.values(): reward_cluster.offload_states() diff --git a/tests/distributed/scheduler/test_protocol_padding.py b/tests/distributed/scheduler/test_protocol_padding.py new file mode 100644 index 000000000..bcab5f42d --- /dev/null +++ b/tests/distributed/scheduler/test_protocol_padding.py @@ -0,0 +1,163 @@ +"""Tests for automatic padding functionality in DynamicSamplingScheduler.""" + +import pytest +import torch +from unittest.mock import Mock, MagicMock + +from roll.distributed.scheduler.generate_scheduler import DynamicSamplingScheduler +from roll.distributed.scheduler.protocol import DataProto, pad_dataproto_to_divisor, unpad_dataproto + + +class TestDynamicSamplingSchedulerPadding: + """Test cases for padding functionality in DynamicSamplingScheduler.""" + + @pytest.fixture + def mock_scheduler(self): + """Create a mock DynamicSamplingScheduler for testing.""" + scheduler = Mock(spec=DynamicSamplingScheduler) + scheduler.actor_cluster = Mock() + scheduler.actor_cluster.dp_size = 4 + scheduler.actor_cluster.generate = Mock() + scheduler.generation_config = {"num_return_sequences": 1} + scheduler.is_val = False + scheduler.batch_size = 7 + scheduler.collect_fn = Mock() + scheduler.get_next_dataset_item = Mock() + scheduler.reward_scheduler = Mock() + scheduler.reward_clusters = [] + scheduler.pipeline_config = {} + scheduler.query_filter_fn = Mock(return_value=True) + scheduler.query_filter_count = 0 + scheduler.response_filter_count = 0 + scheduler.reset_status = Mock() + return scheduler + + def test_padding_when_batch_not_divisible_by_dp_size(self, mock_scheduler): + """Test padding when batch_size is not divisible by dp_size.""" + # Create test data + batch_size = 7 + dp_size = 4 + + # Create actual DataProto for testing + test_data = DataProto.from_single_dict({ + "input_ids": torch.randint(0, 1000, (batch_size, 10)), + "attention_mask": torch.ones(batch_size, 10), + "position_ids": torch.arange(10).unsqueeze(0).repeat(batch_size, 1) + }) + + # Test padding logic with actual data + gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_data, dp_size) + + # Verify padding was applied + assert pad_size == 1 # 7 % 4 = 3, so pad_size = 4 - 3 = 1 + assert len(gen_batch_padded) == 8 + + def test_no_padding_when_batch_divisible_by_dp_size(self, mock_scheduler): + """Test no padding when batch_size is already divisible by dp_size.""" + batch_size = 8 + dp_size = 4 + + # Create actual DataProto for testing + test_data = DataProto.from_single_dict({ + "input_ids": torch.randint(0, 1000, (batch_size, 10)), + "attention_mask": torch.ones(batch_size, 10), + "position_ids": torch.arange(10).unsqueeze(0).repeat(batch_size, 1) + }) + + # Test padding logic with actual data + gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_data, dp_size) + + # Verify no padding was needed + assert pad_size == 0 # 8 % 4 = 0, so no padding needed + assert len(gen_batch_padded) == 8 + + def test_unpadding_restores_original_size(self): + """Test that unpadding restores the original batch size.""" + # Create 
test data + original_size = 7 + pad_size = 1 + + # Create actual batch with padding + original_data = DataProto.from_single_dict({ + "input_ids": torch.randint(0, 1000, (original_size, 10)), + "attention_mask": torch.ones(original_size, 10) + }) + + # Pad the data + padded_data, _ = pad_dataproto_to_divisor(original_data, 4) + + # Test unpadding + result = unpad_dataproto(padded_data, pad_size) + + # Verify unpadding was applied + assert len(result) == original_size + + def test_padding_preserves_data_integrity(self): + """Test that padding preserves data integrity.""" + # Create test DataProto + test_data = DataProto.from_single_dict({ + "input_ids": torch.randint(0, 1000, (3, 10)), + "attention_mask": torch.ones(3, 10) + }) + + # Apply padding + padded_data, pad_size = pad_dataproto_to_divisor(test_data, 4) + + # Verify padded data size + assert len(padded_data) == 4 + assert pad_size == 1 + + # Verify data integrity using proper TensorDict methods + assert "input_ids" in padded_data.batch.keys() + assert "attention_mask" in padded_data.batch.keys() + assert padded_data.batch["input_ids"].shape[0] == 4 + assert padded_data.batch["attention_mask"].shape[0] == 4 + + def test_edge_case_empty_batch(self): + """Test padding behavior with empty batch.""" + # Create empty DataProto + empty_data = DataProto.from_single_dict({ + "input_ids": torch.empty(0, 10), + "attention_mask": torch.empty(0, 10) + }) + + # Apply padding + padded_data, pad_size = pad_dataproto_to_divisor(empty_data, 4) + + # Verify empty batch handling + assert pad_size == 0 + assert len(padded_data) == 0 + + def test_edge_case_single_item_batch(self): + """Test padding behavior with single item batch.""" + # Create single item DataProto + single_data = DataProto.from_single_dict({ + "input_ids": torch.randint(0, 1000, (1, 10)), + "attention_mask": torch.ones(1, 10) + }) + + # Apply padding + padded_data, pad_size = pad_dataproto_to_divisor(single_data, 4) + + # Verify padding + assert pad_size == 3 + assert len(padded_data) == 4 + + def test_backward_compatibility(self, mock_scheduler): + """Test that padding doesn't break existing functionality.""" + batch_size = 8 # Already divisible by dp_size + dp_size = 4 + + # Create actual DataProto for testing + test_data = DataProto.from_single_dict({ + "input_ids": torch.randint(0, 1000, (batch_size, 10)), + "attention_mask": torch.ones(batch_size, 10), + "position_ids": torch.arange(10).unsqueeze(0).repeat(batch_size, 1) + }) + + # Test that existing flow still works + gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_data, dp_size) + + # Verify no padding was applied and flow is unchanged + assert pad_size == 0 + assert len(gen_batch_padded) == batch_size \ No newline at end of file diff --git a/tests/distributed/strategy/test_vllm_strategy_beam_search.py b/tests/distributed/strategy/test_vllm_strategy_beam_search.py new file mode 100644 index 000000000..317799279 --- /dev/null +++ b/tests/distributed/strategy/test_vllm_strategy_beam_search.py @@ -0,0 +1,242 @@ +import pytest +import torch +import sys +from unittest.mock import Mock, patch, MagicMock + +# Mock vllm modules before importing +mock_vllm = Mock() +mock_vllm.__version__ = "0.8.4" +sys.modules['vllm'] = mock_vllm +sys.modules['vllm.sampling_params'] = Mock() +sys.modules['vllm.beam_search'] = Mock() +sys.modules['vllm.lora'] = Mock() +sys.modules['vllm.lora.request'] = Mock() +sys.modules['vllm.utils'] = Mock() +sys.modules['roll.third_party.vllm'] = Mock() + +# Create mock classes +class MockRequestOutput: + 
def __init__(self): + self.request_id = "test_request" + self.outputs = [Mock()] + self.outputs[0].token_ids = [100, 200, 300] + self.outputs[0].finish_reason = "length" + self.outputs[0].logprobs = None + self.finished = True + +class MockSamplingParams: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + self.n = kwargs.get('n', 1) + self.max_tokens = kwargs.get('max_tokens', 50) + +class MockBeamSearchParams: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + self.beam_width = kwargs.get('beam_width', 1) + self.max_tokens = kwargs.get('max_tokens', 50) + +class MockBeamSearchSequence: + def __init__(self, tokens, logprobs, cum_logprob): + self.tokens = tokens + self.logprobs = logprobs + self.cum_logprob = cum_logprob + +class MockBeamSearchOutput: + def __init__(self, sequences): + self.sequences = sequences + +class MockLoRARequest: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + +# Set up the mocks +sys.modules['vllm'].RequestOutput = MockRequestOutput +sys.modules['vllm'].SamplingParams = MockSamplingParams +sys.modules['vllm.sampling_params'].RequestOutputKind = Mock() +sys.modules['vllm.sampling_params'].BeamSearchParams = MockBeamSearchParams +sys.modules['vllm.beam_search'].BeamSearchOutput = MockBeamSearchOutput +sys.modules['vllm.beam_search'].BeamSearchSequence = MockBeamSearchSequence +sys.modules['vllm.lora.request'].LoRARequest = MockLoRARequest +sys.modules['vllm.utils'].random_uuid = Mock(return_value="test_uuid") + +# Now import the actual modules +from roll.distributed.scheduler.protocol import DataProto +from roll.distributed.strategy.vllm_strategy import VllmStrategy +from roll.distributed.executor.worker import Worker + + +class TestVllmStrategyBeamSearch: + """Test cases for VllmStrategy beam search functionality.""" + + @pytest.fixture + def mock_worker(self): + """Create a mock worker for testing.""" + worker = Mock(spec=Worker) + worker.pipeline_config = Mock() + worker.pipeline_config.seed = 42 + worker.worker_config = Mock() + worker.worker_config.strategy_args = Mock() + worker.worker_config.strategy_args.strategy_config = {} + worker.worker_config.model_args = Mock() + worker.worker_config.model_args.model_name_or_path = "test_model" + worker.worker_config.model_args.dtype = "fp16" + worker.worker_config.model_args.lora_target = None + worker.get_free_port = Mock(return_value=12345) + worker.rank = 0 + worker.world_size = 1 + worker.rank_info = Mock() + worker.rank_info.dp_rank = 0 + worker.rank_info.dp_size = 1 + return worker + + @pytest.fixture + def vllm_strategy(self, mock_worker): + """Create VllmStrategy instance for testing.""" + strategy = VllmStrategy(mock_worker) + + # Mock the model and tokenizer + strategy.model = Mock() + strategy.tokenizer = Mock() + strategy.tokenizer.pad_token_id = 0 + strategy.is_lora = False + strategy.is_model_in_gpu = True + + return strategy + + @pytest.fixture + def sample_batch(self): + """Create a sample batch for testing.""" + batch_size = 2 + seq_length = 10 + + # Create sample input tensors + input_ids = torch.randint(1, 1000, (batch_size, seq_length)) + attention_mask = torch.ones(batch_size, seq_length) + + batch = DataProto.from_single_dict({ + "input_ids": input_ids, + "attention_mask": attention_mask + }) + + return batch + + def test_should_use_beam_search_detection(self, vllm_strategy): + """Test beam search detection logic.""" + + # Test with num_beams > 1 + config_with_beam = {"num_beams": 3, 
"max_new_tokens": 50} + assert vllm_strategy._should_use_beam_search(config_with_beam) is True + + # Test with use_beam_search flag + config_with_flag = {"use_beam_search": True, "max_new_tokens": 50} + assert vllm_strategy._should_use_beam_search(config_with_flag) is True + + # Test without beam search parameters + config_without_beam = {"max_new_tokens": 50, "temperature": 0.8} + assert vllm_strategy._should_use_beam_search(config_without_beam) is False + + # Test with num_beams = 1 + config_single_beam = {"num_beams": 1, "max_new_tokens": 50} + assert vllm_strategy._should_use_beam_search(config_single_beam) is False + + def test_generate_with_beam_search_success(self, vllm_strategy, sample_batch): + """Test successful beam search generation.""" + generation_config = {"num_beams": 3, "max_new_tokens": 50} + + # Create mock beam search outputs + beam_width = 3 + batch_size = 2 + + beam_search_outputs = [] + for batch_idx in range(batch_size): + sequences = [] + for beam_idx in range(beam_width): + # Include prompt + generated tokens + prompt_length = 10 + generated_tokens = [100 + beam_idx, 200 + beam_idx, 300 + beam_idx] + full_tokens = list(range(prompt_length)) + generated_tokens + + sequence = MockBeamSearchSequence( + tokens=full_tokens, + logprobs=[], + cum_logprob=-1.0 * beam_idx + ) + sequences.append(sequence) + + output = MockBeamSearchOutput(sequences=sequences) + beam_search_outputs.append(output) + + # Mock the beam_search method + vllm_strategy.model.beam_search = Mock(return_value=beam_search_outputs) + + # Mock breakpoint to avoid actual debugging + with patch('builtins.breakpoint'): + result = vllm_strategy.generate(sample_batch, generation_config) + + # Verify beam_search was called + vllm_strategy.model.beam_search.assert_called_once() + + # Check result shape + assert result.shape[0] == batch_size * beam_width # 2 * 3 = 6 + assert result.shape[1] >= 13 # prompt_length + generated_tokens + + def test_generate_with_beam_search_multimodal(self, vllm_strategy): + """Test beam search generation with multimodal data.""" + generation_config = {"num_beams": 2, "max_new_tokens": 30} + + # Create multimodal batch + multimodal_data = [ + { + "prompt_token_ids": [1, 2, 3, 4, 5], + "multi_modal_data": {"image": "test_image.jpg"} + }, + { + "prompt_token_ids": [6, 7, 8, 9, 10], + "multi_modal_data": {"image": "test_image2.jpg"} + } + ] + + # Create a batch with dummy tensors to satisfy DataProto requirements + batch = DataProto.from_single_dict({ + "input_ids": torch.randint(1, 1000, (2, 5)), + "attention_mask": torch.ones(2, 5) + }) + batch.non_tensor_batch["multi_modal_data"] = multimodal_data + + # Create mock beam search outputs + beam_search_outputs = [] + for batch_idx in range(2): + sequences = [] + for beam_idx in range(2): + prompt_length = 5 + generated_tokens = [100 + beam_idx, 200 + beam_idx] + full_tokens = multimodal_data[batch_idx]["prompt_token_ids"] + generated_tokens + + sequence = MockBeamSearchSequence( + tokens=full_tokens, + logprobs=[], + cum_logprob=-1.0 * beam_idx + ) + sequences.append(sequence) + + output = MockBeamSearchOutput(sequences=sequences) + beam_search_outputs.append(output) + + # Mock the beam_search method + vllm_strategy.model.beam_search = Mock(return_value=beam_search_outputs) + + # Mock breakpoint to avoid actual debugging + with patch('builtins.breakpoint'): + result = vllm_strategy.generate(batch, generation_config) + + # Verify beam_search was called with correct prompts + vllm_strategy.model.beam_search.assert_called_once() + call_args 
= vllm_strategy.model.beam_search.call_args + assert call_args[1]['prompts'] == multimodal_data + + # Check result shape + assert result.shape[0] == 4 # batch_size * beam_width From d8e5c94b415298a0cd3adbc1dec53aeec25d7b99 Mon Sep 17 00:00:00 2001 From: "hongzhen.yj" Date: Wed, 5 Nov 2025 16:59:35 +0800 Subject: [PATCH 10/58] (fix): ensure compatibility with transformers version check for causal mask update. --- roll/utils/context_parallel/monkey_patch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/roll/utils/context_parallel/monkey_patch.py b/roll/utils/context_parallel/monkey_patch.py index c779ce9fa..b64b3c339 100644 --- a/roll/utils/context_parallel/monkey_patch.py +++ b/roll/utils/context_parallel/monkey_patch.py @@ -1,6 +1,8 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.models.qwen2.modeling_qwen2 import Qwen2Model + + from roll.utils.logging import get_logger from roll.utils.packages import is_transformers_version_greater_than From 77325f7a827d4ab1485b5db83c130804efbbf3c6 Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Fri, 5 Dec 2025 16:52:43 +0800 Subject: [PATCH 11/58] (feat): support pytorch280 docker. --- roll/third_party/vllm/__init__.py | 2 +- roll/third_party/vllm/vllm_0_11_0/llm.py | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/roll/third_party/vllm/__init__.py b/roll/third_party/vllm/__init__.py index 5a504d0f7..3f6c19a28 100644 --- a/roll/third_party/vllm/__init__.py +++ b/roll/third_party/vllm/__init__.py @@ -21,7 +21,7 @@ elif Version("0.10.2") == Version(vllm.__version__): from roll.third_party.vllm.vllm_0_10_2.llm import Llm0102 LLM = Llm0102 -elif Version("0.11.0") == Version(vllm.__version__): +elif Version("0.11.1rc2.dev0+gc3a722fcb.d20251021") == Version(vllm.__version__) or Version("0.11.0") == Version(vllm.__version__): from roll.third_party.vllm.vllm_0_11_0.llm import Llm0110 LLM = Llm0110 else: diff --git a/roll/third_party/vllm/vllm_0_11_0/llm.py b/roll/third_party/vllm/vllm_0_11_0/llm.py index ec4bff644..dc75e1340 100644 --- a/roll/third_party/vllm/vllm_0_11_0/llm.py +++ b/roll/third_party/vllm/vllm_0_11_0/llm.py @@ -7,9 +7,18 @@ import torch from pydantic import ValidationError from vllm import LLM, EngineArgs, SamplingParams, envs -from vllm.config import (CompilationConfig, StructuredOutputsConfig, - ModelDType, TokenizerMode, - is_init_field) +from vllm.config import ( + CompilationConfig, + StructuredOutputsConfig, + is_init_field, +) + +try: + # 0.11.1rc2.dev0+gc3a722fcb.d20251021 has import diff + from vllm.config.model import ModelDType, TokenizerMode +except ImportError: + from vllm.config import ModelDType, TokenizerMode + from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides, PoolerConfig, RunnerOption) From accefedd704ecad2bfc454627f501b3357ba1af7 Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Fri, 7 Nov 2025 14:52:35 +0800 Subject: [PATCH 12/58] (fix): fix agentic val get_batch state in redundancy env. 
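If val_batch_size is not a multiple of the number of validation envs, the redundant envs can leave the val get_batch state inconsistent, so the divisibility is now asserted up front and the derived per-env trajectory count reduces to an exact division. Minimal sketch with the names used in agentic_config.py, for illustration only:

    assert val_batch_size % val_env_num == 0
    # ceil division; equals val_batch_size // val_env_num once divisibility holds
    traj_per_env = (val_batch_size + val_env_num - 1) // val_env_num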
--- roll/pipeline/agentic/agentic_config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/roll/pipeline/agentic/agentic_config.py b/roll/pipeline/agentic/agentic_config.py index dc7a917a8..f9c19c302 100644 --- a/roll/pipeline/agentic/agentic_config.py +++ b/roll/pipeline/agentic/agentic_config.py @@ -146,6 +146,8 @@ def __post_init__(self): if self.val_batch_size < 0: self.val_env_manager.max_traj_per_env = sys.maxsize else: + assert self.val_batch_size % val_env_num == 0, f"val_batch_size {self.val_batch_size} must be divisible by val_env_num {val_env_num}, equal best" + traj_per_env = (self.val_batch_size + val_env_num - 1) // val_env_num if self.val_env_manager.max_traj_per_env < 0: self.val_env_manager.max_traj_per_env = traj_per_env From 8629e85089fc5d1ecaedd84c0f967ed4c3ebae5e Mon Sep 17 00:00:00 2001 From: bzd02333762 Date: Tue, 18 Nov 2025 15:45:04 +0800 Subject: [PATCH 13/58] (feat): Add support for Qwen-3-next on AMD GPUs. --- .../agent_val_frozen_lake_amd.yaml | 5 +- .../agent_val_frozen_lake_async_amd.yaml | 163 +++++++++++ .../submit_pipeline_amd.sh | 49 ++++ .../submit_pipeline_amd_async.sh | 49 ++++ .../rlvr_config_amd.yaml | 38 ++- .../rlvr_config_amd_async.yaml | 157 ++++++++++ .../rlvr_lora_zero3_amd.yaml | 267 ++++++++++++++++++ .../submit_pipeline_amd.sh | 49 ++++ .../submit_pipeline_amd_async.sh | 49 ++++ .../submit_pipeline_amd_zero3_lora.sh | 46 +++ .../qwen2.5-vl-7B-math/submit_pipeline_amd.sh | 49 ++++ .../rlvr_config_amd.yaml | 267 ++++++++++++++++++ .../submit_pipeline_amd.sh | 49 ++++ .../rlvr_config_amd.yaml | 10 +- .../submit_pipeline_amd.sh | 49 ++++ .../rlvr_config_amd.yaml | 196 +++++++++++++ .../submit_pipeline_amd.sh | 49 ++++ .../src/mcore_adapter/platforms/rocm.py | 7 +- roll/distributed/strategy/vllm_strategy.py | 20 +- roll/models/model_providers.py | 2 +- roll/platforms/rocm.py | 5 +- 21 files changed, 1542 insertions(+), 33 deletions(-) create mode 100644 examples/qwen2.5-0.5B-agentic/agent_val_frozen_lake_async_amd.yaml create mode 100644 examples/qwen2.5-0.5B-agentic/submit_pipeline_amd.sh create mode 100644 examples/qwen2.5-0.5B-agentic/submit_pipeline_amd_async.sh create mode 100644 examples/qwen2.5-7B-rlvr_megatron/rlvr_config_amd_async.yaml create mode 100644 examples/qwen2.5-7B-rlvr_megatron/rlvr_lora_zero3_amd.yaml create mode 100644 examples/qwen2.5-7B-rlvr_megatron/submit_pipeline_amd.sh create mode 100644 examples/qwen2.5-7B-rlvr_megatron/submit_pipeline_amd_async.sh create mode 100644 examples/qwen2.5-7B-rlvr_megatron/submit_pipeline_amd_zero3_lora.sh create mode 100644 examples/qwen2.5-vl-7B-math/submit_pipeline_amd.sh create mode 100644 examples/qwen3-235BA22B-rlvr_megatron/rlvr_config_amd.yaml create mode 100644 examples/qwen3-235BA22B-rlvr_megatron/submit_pipeline_amd.sh create mode 100644 examples/qwen3-30BA3B-rlvr_megatron/submit_pipeline_amd.sh create mode 100644 examples/qwen3-next-80BA3B-rlvr_megatron/rlvr_config_amd.yaml create mode 100644 examples/qwen3-next-80BA3B-rlvr_megatron/submit_pipeline_amd.sh diff --git a/examples/qwen2.5-0.5B-agentic/agent_val_frozen_lake_amd.yaml b/examples/qwen2.5-0.5B-agentic/agent_val_frozen_lake_amd.yaml index 99d4877ee..4d8efa742 100644 --- a/examples/qwen2.5-0.5B-agentic/agent_val_frozen_lake_amd.yaml +++ b/examples/qwen2.5-0.5B-agentic/agent_val_frozen_lake_amd.yaml @@ -107,7 +107,7 @@ actor_infer: strategy_args: strategy_name: vllm strategy_config: - gpu_memory_utilization: 0.4 + gpu_memory_utilization: 0.6 block_size: 16 load_format: auto device_mapping: list(range(0,8)) @@ -131,7 
+131,6 @@ reward_normalization: method: mean_std # asym_clip / identity / mean_std train_env_manager: - format_penalty: -0.15 # sokoban env penalty_for_step=-0.1 max_env_num_per_worker: 16 num_env_groups: 128 # under the same group, the env config and env seed are ensured to be equal @@ -163,8 +162,8 @@ custom_envs: ${custom_env.FrozenLakeThink} FrozenLakeLocallyDefineExamples: # Can import from unified envs config or define dict locally env_type: frozen_lake + max_steps: ${max_actions_per_traj} max_tokens_per_step: ${max_tokens_per_step} - user_prompt_format: ${user_prompt_think_format} env_manager_cls: ${env_manager_cls} use_thread_lock: true env_config: diff --git a/examples/qwen2.5-0.5B-agentic/agent_val_frozen_lake_async_amd.yaml b/examples/qwen2.5-0.5B-agentic/agent_val_frozen_lake_async_amd.yaml new file mode 100644 index 000000000..3696af5ec --- /dev/null +++ b/examples/qwen2.5-0.5B-agentic/agent_val_frozen_lake_async_amd.yaml @@ -0,0 +1,163 @@ +defaults: + - ../config/traj_envs@_here_ + - ../config/deepspeed_zero@_here_ + - ../config/deepspeed_zero2@_here_ + - ../config/deepspeed_zero3@_here_ + - ../config/deepspeed_zero3_cpuoffload@_here_ + +hydra: + run: + dir: . + output_subdir: null + +exp_name: "agentic_pipeline_async" +seed: 42 +logging_dir: ./output/logs +output_dir: ./output +render_save_dir: ./output/render +system_envs: + USE_MODELSCOPE: '1' + +#track_with: wandb +#tracker_kwargs: +# api_key: +# project: roll-agentic +# name: ${exp_name}_sokoban +# notes: "agentic_pipeline" +# tags: +# - agentic +# - roll +# - baseline + +track_with: tensorboard +tracker_kwargs: + log_dir: /data/oss_bucket_0/yali/llm/tensorboard/roll_exp/agentic_frozen_lake_async + +checkpoint_config: + type: file_system + output_dir: /data/cpfs_0/rl_examples/models/${exp_name} + +num_gpus_per_node: 8 + +max_steps: 1024 +save_steps: 10000 +logging_steps: 1 +eval_steps: 10 +resume_from_checkpoint: false + +async_generation_ratio: 1 + +rollout_batch_size: 1024 +val_batch_size: 1024 +sequence_length: 8192 + +advantage_clip: 0.2 +ppo_epochs: 1 +adv_estimator: "grpo" +#pg_clip: 0.1 +#dual_clip_loss: True +init_kl_coef: 0.0 +whiten_advantages: true +entropy_loss_coef: 0 +max_grad_norm: 1.0 + +pretrain: Qwen/Qwen2.5-0.5B-Instruct +reward_pretrain: Qwen/Qwen2.5-0.5B-Instruct + +actor_train: + model_args: + attn_implementation: fa2 + disable_gradient_checkpointing: false + dtype: bf16 + model_type: ~ + training_args: + learning_rate: 1.0e-6 + weight_decay: 0 + per_device_train_batch_size: 2 + gradient_accumulation_steps: 128 + warmup_steps: 10 + lr_scheduler_type: cosine + data_args: + template: qwen2_5 + strategy_args: +# strategy_name: deepspeed_train +# strategy_config: ${deepspeed_zero3} + strategy_name: megatron_train + strategy_config: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + use_distributed_optimizer: true + recompute_granularity: full + device_mapping: list(range(0,4)) + infer_batch_size: 2 + +actor_infer: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + generating_args: + max_new_tokens: 128 # single-turn response length + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: 1 + data_args: + template: qwen2_5 + strategy_args: + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.6 + block_size: 16 + load_format: auto + device_mapping: list(range(4,8)) + +reference: + model_args: + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: ~ + 
data_args: + template: qwen2_5 + strategy_args: + strategy_name: hf_infer + strategy_config: ~ + device_mapping: list(range(0,4)) + infer_batch_size: 2 + +reward_normalization: + grouping: traj_group_id # 可以tags(env_type)/traj_group_id(group)/batch(rollout_batch)... group_by计算reward/adv + method: mean_std # asym_clip / identity / mean_std + +train_env_manager: + format_penalty: -0.15 # sokoban env penalty_for_step=-0.1 + max_env_num_per_worker: 16 + num_env_groups: 128 + # under the same group, the env config and env seed are ensured to be equal + group_size: 8 + tags: [FrozenLake] + num_groups_partition: [128] # If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation + +val_env_manager: + max_env_num_per_worker: 32 + num_env_groups: 1024 + group_size: 1 # should be set to 1 because val temperature is set to 0 and same prompt leads to same output + tags: [SimpleSokoban, LargerSokoban, SokobanDifferentGridVocab, FrozenLake] + num_groups_partition: [256, 256, 256, 256] # TODO: If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation + +# Here, you can override variables defined in the imported envs. max_tokens_per_step: 128 in custom_env.SimpleSokoban, here replaced by 64 +max_tokens_per_step: 64 + +custom_envs: + SimpleSokoban: + ${custom_env.SimpleSokoban} + LargerSokoban: + ${custom_env.LargerSokoban} + SokobanDifferentGridVocab: + ${custom_env.SokobanDifferentGridVocab} + FrozenLake: + ${custom_env.FrozenLake} + FrozenLakeThink: + ${custom_env.FrozenLakeThink} \ No newline at end of file diff --git a/examples/qwen2.5-0.5B-agentic/submit_pipeline_amd.sh b/examples/qwen2.5-0.5B-agentic/submit_pipeline_amd.sh new file mode 100644 index 000000000..54d095440 --- /dev/null +++ b/examples/qwen2.5-0.5B-agentic/submit_pipeline_amd.sh @@ -0,0 +1,49 @@ +#!/bin/bash +set +x +source "examples/scripts/config.sh" + +WORKER_COUNT=1 +CONFIG_FILE="agent_val_frozen_lake_amd.yaml" +# 替换为mos uri +NEBULA_MODEL="" +ENTRY_FILE="examples/start_agentic_pipeline.py" + +CONFIG_PATH=$(basename $(dirname $0)) +CONFIG_NAME="${CONFIG_FILE%.yaml}" +JOB_NAME="$CONFIG_PATH-$CONFIG_NAME" + + +QUEUE="nebula_test2_308x_gpu_hang" +# QUEUE="nebula_test_308x" +ENVS="NCCL_PF_UCM_TIMEOUT=600000,NCCL_SOCKET_IFNAME=bond0" +# ENVS="NCCL_PF_UCM_TIMEOUT=600000" + +echo "JOB_NAME: ${JOB_NAME}" +echo "WORKER_COUNT: ${WORKER_COUNT}" +echo "CONFIG_NAME: ${CONFIG_NAME}" +echo "CONFIG_PATH: ${CONFIG_PATH}" +echo "ENTRY_FILE: ${ENTRY_FILE}" + +args="--config_name ${CONFIG_NAME} --config_path ${CONFIG_PATH}" + +mdl_args="--queue=${QUEUE} \ + --entry=${ENTRY_FILE} \ + --worker_count=${WORKER_COUNT} \ + --file.cluster_file=examples/scripts/cluster.json \ + --job_name=${JOB_NAME} \ + --algo_name=pytorch280 \ + --requirements_file_name=nebula_patch/requirements/requirements_torch280_vllm_amd.txt \ + --oss_appendable=true \ + --_NEBULA_MODEL=${NEBULA_MODEL} \ + --nebula_model=${NEBULA_MODEL} \ + --env=${ENVS} \ + --force \ + " +if [ -n "${OPENLM_TOKEN}" ]; then + mdl_args="${mdl_args} --env=OPENLM_TOKEN=${OPENLM_TOKEN}" +fi + +echo ${args} +echo ${mdl_args} + +nebulactl run mdl --user_params="${args}" $mdl_args diff --git a/examples/qwen2.5-0.5B-agentic/submit_pipeline_amd_async.sh b/examples/qwen2.5-0.5B-agentic/submit_pipeline_amd_async.sh new file mode 100644 index 000000000..aa06c2054 --- /dev/null +++ b/examples/qwen2.5-0.5B-agentic/submit_pipeline_amd_async.sh @@ -0,0 +1,49 @@ 
+#!/bin/bash +set +x +source "examples/scripts/config.sh" + +WORKER_COUNT=1 +CONFIG_FILE="agent_val_frozen_lake_async_amd.yaml" +# 替换为mos uri +NEBULA_MODEL="" +ENTRY_FILE="examples/start_agentic_pipeline.py" + +CONFIG_PATH=$(basename $(dirname $0)) +CONFIG_NAME="${CONFIG_FILE%.yaml}" +JOB_NAME="$CONFIG_PATH-$CONFIG_NAME" + + +QUEUE="nebula_test2_308x_gpu_hang" +# QUEUE="nebula_test_308x" +ENVS="NCCL_PF_UCM_TIMEOUT=600000,NCCL_SOCKET_IFNAME=bond0" +# ENVS="NCCL_PF_UCM_TIMEOUT=600000" + +echo "JOB_NAME: ${JOB_NAME}" +echo "WORKER_COUNT: ${WORKER_COUNT}" +echo "CONFIG_NAME: ${CONFIG_NAME}" +echo "CONFIG_PATH: ${CONFIG_PATH}" +echo "ENTRY_FILE: ${ENTRY_FILE}" + +args="--config_name ${CONFIG_NAME} --config_path ${CONFIG_PATH}" + +mdl_args="--queue=${QUEUE} \ + --entry=${ENTRY_FILE} \ + --worker_count=${WORKER_COUNT} \ + --file.cluster_file=examples/scripts/cluster.json \ + --job_name=${JOB_NAME} \ + --algo_name=pytorch280 \ + --requirements_file_name=nebula_patch/requirements/requirements_torch280_vllm_amd.txt \ + --oss_appendable=true \ + --_NEBULA_MODEL=${NEBULA_MODEL} \ + --nebula_model=${NEBULA_MODEL} \ + --env=${ENVS} \ + --force \ + " +if [ -n "${OPENLM_TOKEN}" ]; then + mdl_args="${mdl_args} --env=OPENLM_TOKEN=${OPENLM_TOKEN}" +fi + +echo ${args} +echo ${mdl_args} + +nebulactl run mdl --user_params="${args}" $mdl_args diff --git a/examples/qwen2.5-7B-rlvr_megatron/rlvr_config_amd.yaml b/examples/qwen2.5-7B-rlvr_megatron/rlvr_config_amd.yaml index c80f24dc0..4d95333a6 100644 --- a/examples/qwen2.5-7B-rlvr_megatron/rlvr_config_amd.yaml +++ b/examples/qwen2.5-7B-rlvr_megatron/rlvr_config_amd.yaml @@ -9,6 +9,7 @@ logging_dir: ./output/logs output_dir: ./output system_envs: USE_MODELSCOPE: '1' + checkpoint_config: type: file_system output_dir: /data/cpfs_0/rl_examples/models/${exp_name} @@ -34,38 +35,25 @@ logging_steps: 1 eval_steps: 10 resume_from_checkpoint: false -# -------------------------- -# grpo related rollout_batch_size: 64 # prompt prompt_length: 2048 response_length: 4096 -adv_estimator: "grpo" num_return_sequences_in_group: 8 ppo_epochs: 1 -use_kl_loss: true -kl_loss_coef: 0.001 -loss_agg_mode: "seq-mean-token-sum" - +adv_estimator: "reinforce" -# ppo related -# advantage -whiten_advantages: true -advantage_clip: 2.0 -dual_clip_loss: true # clip +value_clip: 0.5 reward_clip: 10 +advantage_clip: 2.0 +dual_clip_loss: true # normalize norm_mean_type: ~ norm_std_type: ~ -# reward -add_token_level_kl: false - -# -------------------------- -# Additional optimization knobs # data mask max_len_mask: true difficulty_mask: true @@ -77,13 +65,17 @@ error_max_len_clip: false difficulty_loss_weight: false length_loss_weight: false +# reward +add_token_level_kl: false + +# advantage +whiten_advantages: true + # dynamic sampling scheduler # use_additional_prompts: true # max_running_requests: 256 # is_num_return_sequences_expand: false -# -------------------------- - pretrain: Qwen/Qwen2.5-7B reward_pretrain: Qwen/Qwen2.5-7B @@ -253,7 +245,11 @@ rewards: data_args: template: qwen2_5 strategy_args: - strategy_name: hf_infer - strategy_config: null + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.8 + block_size: 16 + max_model_len: 8000 + load_format: auto device_mapping: list(range(12,16)) infer_batch_size: 4 \ No newline at end of file diff --git a/examples/qwen2.5-7B-rlvr_megatron/rlvr_config_amd_async.yaml b/examples/qwen2.5-7B-rlvr_megatron/rlvr_config_amd_async.yaml new file mode 100644 index 000000000..579245ad3 --- /dev/null +++ 
b/examples/qwen2.5-7B-rlvr_megatron/rlvr_config_amd_async.yaml @@ -0,0 +1,157 @@ +hydra: + run: + dir: . + output_subdir: null + +exp_name: "qwen2.5-7B-async-rlvr-config" +seed: 42 +logging_dir: ./output/logs +output_dir: ./output +system_envs: + USE_MODELSCOPE: '1' + +checkpoint_config: + type: file_system + output_dir: /data/cpfs_0/rl_examples/lzc/models/${exp_name} + +track_with: ml_tracker + +num_gpus_per_node: 8 + +max_steps: 1000 +save_steps: 100 +logging_steps: 1 +resume_from_checkpoint: false + + +rollout_batch_size: 64 # prompt +prompt_length: 2048 +response_length: 8192 + +async_generation_ratio: 1 +is_num_return_sequences_expand: true +num_return_sequences_in_group: 8 +ppo_epochs: 1 +adv_estimator: "reinforce" + +# clip +value_clip: 0.5 +reward_clip: 10 +advantage_clip: 2.0 +dual_clip_loss: true + +# normalize +norm_mean_type: ~ +norm_std_type: ~ + +# data mask +max_len_mask: true +difficulty_mask: true +difficulty_low_threshold: 0.1 +difficulty_high_threshold: 0.95 +error_max_len_clip: false + +# data weight +difficulty_loss_weight: false +length_loss_weight: false + +# reward +add_token_level_kl: false + +# advantage +whiten_advantages: true + +# dynamic sampling scheduler +# use_additional_prompts: true +# max_running_requests: 256 +# is_num_return_sequences_expand: false + +# pretrain: /data/cpfs_0/common/models/Qwen-2.5-7B-Instruct +# reward_pretrain: /data/cpfs_0/common/models/Qwen-2.5-7B-Instruct + +pretrain: Qwen/Qwen2.5-7B +reward_pretrain: Qwen/Qwen2.5-7B + +actor_train: + model_args: + dtype: bf16 + model_type: ~ + training_args: + learning_rate: 1.0e-6 + weight_decay: 0 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 64 + warmup_steps: 1 + num_train_epochs: 5 + data_args: + template: qwen2_5 + file_name: + - data/math_deepmath_deal.jsonl + domain_interleave_probs: + math_rule: 1.0 + dataset_dir: data + messages: messages + interleave_probs: "1.0" + preprocessing_num_workers: 16 + strategy_args: + strategy_name: megatron_train + strategy_config: + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + sequence_parallel: true + use_distributed_optimizer: true + apply_rope_fusion: true + overlap_grad_reduce: true + bias_activation_fusion: true + bf16: true + device_mapping: list(range(0,16)) + infer_batch_size: 2 + +actor_infer: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + generating_args: + max_new_tokens: ${response_length} + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: ${num_return_sequences_in_group} + data_args: + template: qwen2_5 + strategy_args: + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.6 + block_size: 16 + max_model_len: 8000 + device_mapping: list(range(16,24)) + infer_batch_size: 1 + +reference: + model_args: + dtype: bf16 + model_type: ~ + data_args: + template: qwen2_5 + strategy_args: + strategy_name: megatron_infer + strategy_config: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + bf16: true + device_mapping: list(range(0,16)) + infer_batch_size: 2 + +rewards: + math_rule: + worker_cls: roll.pipeline.rlvr.rewards.math_rule_reward_worker.MathRuleRewardWorker + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2_5 + tag_included: [deepmath_103k, aime] + world_size: 8 + infer_batch_size: 1 diff --git a/examples/qwen2.5-7B-rlvr_megatron/rlvr_lora_zero3_amd.yaml b/examples/qwen2.5-7B-rlvr_megatron/rlvr_lora_zero3_amd.yaml new file mode 100644 index 
000000000..26119abc7 --- /dev/null +++ b/examples/qwen2.5-7B-rlvr_megatron/rlvr_lora_zero3_amd.yaml @@ -0,0 +1,267 @@ +defaults: + - ../config/deepspeed_zero@_here_ + - ../config/deepspeed_zero2@_here_ + - ../config/deepspeed_zero3@_here_ + - ../config/deepspeed_zero3_cpuoffload@_here_ + +hydra: + run: + dir: . + output_subdir: null + +exp_name: "qwen2.5-7B-rlvr-config" +seed: 42 +logging_dir: ./output/logs +output_dir: ./output +system_envs: + USE_MODELSCOPE: '1' + +checkpoint_config: + type: file_system + output_dir: /data/cpfs_0/rl_examples/models/${exp_name} + +#track_with: wandb +#tracker_kwargs: +# api_key: +# project: roll_examples +# notes: roll_examples +# tags: +# - rlvr +# - baseline + +track_with: tensorboard +tracker_kwargs: + log_dir: /data/oss_bucket_0/rl_examples/llm/tensorboard/roll_exp/rlvr + +num_gpus_per_node: 8 + +max_steps: 500 +save_steps: 100 +logging_steps: 1 +eval_steps: 10 +resume_from_checkpoint: false + + +rollout_batch_size: 64 # prompt +prompt_length: 2048 +response_length: 4096 + +num_return_sequences_in_group: 8 +ppo_epochs: 1 +adv_estimator: "reinforce" + +# clip +value_clip: 0.5 +reward_clip: 10 +advantage_clip: 2.0 +dual_clip_loss: true + +# normalize +reward_norm: null +reward_shift: false +reward_scale: false + +# data mask +max_len_mask: true +difficulty_mask: true +difficulty_low_threshold: 0.1 +difficulty_high_threshold: 0.95 +error_max_len_clip: false + +# data weight +difficulty_loss_weight: false +length_loss_weight: false + +# reward +add_token_level_kl: false + +# advantage +whiten_advantages: true + +# lora +lora_target: o_proj,q_proj,k_proj,v_proj +lora_rank: 32 +lora_alpha: 32 + +# dynamic sampling scheduler +# use_additional_prompts: true +# max_running_requests: 256 +# is_num_return_sequences_expand: false + +pretrain: Qwen/Qwen2.5-7B +reward_pretrain: Qwen/Qwen2.5-7B + +validation: + data_args: + template: qwen2_5 + file_name: + - data/math_benchmarks.jsonl + generating_args: + max_new_tokens: ${response_length} + top_p: 0.6 + top_k: 50 + num_beams: 1 + temperature: 0.6 + num_return_sequences: 1 + + +actor_train: + model_args: + attn_implementation: fa2 + # Recomputed tensor size does not match for LoRA with Zero3 when activating checkpointing, See https://github.com/huggingface/transformers/issues/34928 for details + disable_gradient_checkpointing: true + dtype: bf16 + lora_target: ${lora_target} + lora_rank: ${lora_rank} + lora_alpha: ${lora_alpha} + model_type: ~ + training_args: + learning_rate: 1.0e-5 + weight_decay: 0 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 32 + warmup_steps: 20 + num_train_epochs: 50 + data_args: + template: qwen2_5 + file_name: + - data/code_KodCode_data.jsonl + - data/llm_judge_Multi-subject-RLVR_deal_new.jsonl + - data/math_deepmath_deal.jsonl + - data/general_ifeval_train_deal.jsonl + - data/general_CrossThink-QA_deal.jsonl + domain_interleave_probs: + math_rule: 0.4 + code_sandbox: 0.3 + llm_judge: 0.1 + crossthinkqa: 0.1 + ifeval: 0.1 + dataset_dir: data + messages: messages + interleave_probs: "1.0" + preprocessing_num_workers: 16 + strategy_args: + strategy_name: deepspeed_train + strategy_config: ${deepspeed_zero3} + device_mapping: list(range(0,16)) + infer_batch_size: 4 + +actor_infer: + model_args: + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + lora_target: ${lora_target} + lora_rank: ${lora_rank} + lora_alpha: ${lora_alpha} + generating_args: + max_new_tokens: ${response_length} + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + 
num_return_sequences: ${num_return_sequences_in_group} + data_args: + template: qwen2_5 + strategy_args: + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.6 + # https://github.com/vllm-project/vllm/issues/9452 + enforce_eager: false + block_size: 16 + max_model_len: 8000 + device_mapping: list(range(0,12)) + infer_batch_size: 1 + +reference: + model_args: + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: ~ + data_args: + template: qwen2_5 + strategy_args: + strategy_name: hf_infer + strategy_config: ~ + device_mapping: list(range(0,16)) + infer_batch_size: 8 + +rewards: + crossthinkqa: + worker_cls: roll.pipeline.rlvr.rewards.crossthinkqa_rule_reward_worker.CrossThinkQARuleRewardWorker + reward_type: soft + response_length_penalty_coef: 0.0 + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2_5 + tag_included: [crossthinkqa] + world_size: 8 + infer_batch_size: 4 + ifeval: + worker_cls: roll.pipeline.rlvr.rewards.ifeval_rule_reward_worker.GeneralRuleRewardWorker + reward_type: soft + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2_5 + tag_included: [ifeval] + world_size: 8 + infer_batch_size: 4 + math_rule: + worker_cls: roll.pipeline.rlvr.rewards.math_rule_reward_worker.MathRuleRewardWorker + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2_5 + tag_included: [deepmath_103k, aime] + world_size: 8 + infer_batch_size: 1 +# dynamic filter config +# query_filter_config: +# type: mean_filter +# filter_args: +# threshold_up: 0.9 +# threshold_down: 0.1 + code_sandbox: + use_local: true + worker_cls: roll.pipeline.rlvr.rewards.code_sandbox_reward_worker.CodeSandboxRewardWorker + tag_included: [KodCode] + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2_5 + world_size: 8 + infer_batch_size: 1 +# query_filter_config: +# type: std_filter +# filter_args: +# std_threshold: 0 + llm_judge: + # NOTE: llm as judge 也需要gpu, 不能和actor infer共享gpu + worker_cls: roll.pipeline.rlvr.rewards.llm_judge_reward_worker.LLMJudgeRewardWorker + judge_prompt: Qwen2.5-7B-Instruct-RLVR-prompt + judge_model_type: inference + tag_included: [RLVR] + model_args: + model_name_or_path: virtuoussy/Qwen2.5-7B-Instruct-RLVR + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: trl + generating_args: + max_new_tokens: 100 + top_p: 0.8 + top_k: 50 + num_beams: 1 + temperature: 0.8 + num_return_sequences: 1 + data_args: + template: qwen2_5 + strategy_args: + strategy_name: hf_infer + strategy_config: null + device_mapping: list(range(12,16)) + infer_batch_size: 4 \ No newline at end of file diff --git a/examples/qwen2.5-7B-rlvr_megatron/submit_pipeline_amd.sh b/examples/qwen2.5-7B-rlvr_megatron/submit_pipeline_amd.sh new file mode 100644 index 000000000..fccb1ab1c --- /dev/null +++ b/examples/qwen2.5-7B-rlvr_megatron/submit_pipeline_amd.sh @@ -0,0 +1,49 @@ +#!/bin/bash +set +x +source "examples/scripts/config.sh" + +WORKER_COUNT=2 +CONFIG_FILE="rlvr_config_amd.yaml" +# 替换为mos uri +NEBULA_MODEL="" +ENTRY_FILE="examples/start_rlvr_pipeline.py" + +CONFIG_PATH=$(basename $(dirname $0)) +CONFIG_NAME="${CONFIG_FILE%.yaml}" +JOB_NAME="$CONFIG_PATH-$CONFIG_NAME" + + +QUEUE="nebula_test2_308x_gpu_hang" +# QUEUE="nebula_test_308x" +ENVS="NCCL_PF_UCM_TIMEOUT=600000,NCCL_SOCKET_IFNAME=bond0" +# ENVS="NCCL_PF_UCM_TIMEOUT=600000" + +echo "JOB_NAME: ${JOB_NAME}" +echo "WORKER_COUNT: ${WORKER_COUNT}" +echo 
"CONFIG_NAME: ${CONFIG_NAME}" +echo "CONFIG_PATH: ${CONFIG_PATH}" +echo "ENTRY_FILE: ${ENTRY_FILE}" + +args="--config_name ${CONFIG_NAME} --config_path ${CONFIG_PATH}" + +mdl_args="--queue=${QUEUE} \ + --entry=${ENTRY_FILE} \ + --worker_count=${WORKER_COUNT} \ + --file.cluster_file=examples/scripts/cluster.json \ + --job_name=${JOB_NAME} \ + --algo_name=pytorch280 \ + --requirements_file_name=nebula_patch/requirements/requirements_torch280_vllm_amd.txt \ + --oss_appendable=true \ + --_NEBULA_MODEL=${NEBULA_MODEL} \ + --nebula_model=${NEBULA_MODEL} \ + --env=${ENVS} \ + --force \ + " +if [ -n "${OPENLM_TOKEN}" ]; then + mdl_args="${mdl_args} --env=OPENLM_TOKEN=${OPENLM_TOKEN}" +fi + +echo ${args} +echo ${mdl_args} + +nebulactl run mdl --user_params="${args}" $mdl_args diff --git a/examples/qwen2.5-7B-rlvr_megatron/submit_pipeline_amd_async.sh b/examples/qwen2.5-7B-rlvr_megatron/submit_pipeline_amd_async.sh new file mode 100644 index 000000000..484218310 --- /dev/null +++ b/examples/qwen2.5-7B-rlvr_megatron/submit_pipeline_amd_async.sh @@ -0,0 +1,49 @@ +#!/bin/bash +set +x +source "examples/scripts/config.sh" + +WORKER_COUNT=3 +CONFIG_FILE="rlvr_config_amd_async.yaml" +# 替换为mos uri +NEBULA_MODEL="" +ENTRY_FILE="examples/start_rlvr_pipeline.py" + +CONFIG_PATH=$(basename $(dirname $0)) +CONFIG_NAME="${CONFIG_FILE%.yaml}" +JOB_NAME="$CONFIG_PATH-$CONFIG_NAME" + + +QUEUE="nebula_test2_308x_gpu_hang" +# QUEUE="nebula_test_308x" +ENVS="NCCL_PF_UCM_TIMEOUT=600000,NCCL_SOCKET_IFNAME=bond0" +# ENVS="NCCL_PF_UCM_TIMEOUT=600000" + +echo "JOB_NAME: ${JOB_NAME}" +echo "WORKER_COUNT: ${WORKER_COUNT}" +echo "CONFIG_NAME: ${CONFIG_NAME}" +echo "CONFIG_PATH: ${CONFIG_PATH}" +echo "ENTRY_FILE: ${ENTRY_FILE}" + +args="--config_name ${CONFIG_NAME} --config_path ${CONFIG_PATH}" + +mdl_args="--queue=${QUEUE} \ + --entry=${ENTRY_FILE} \ + --worker_count=${WORKER_COUNT} \ + --file.cluster_file=examples/scripts/cluster.json \ + --job_name=${JOB_NAME} \ + --algo_name=pytorch280 \ + --requirements_file_name=nebula_patch/requirements/requirements_torch280_vllm_amd.txt \ + --oss_appendable=true \ + --_NEBULA_MODEL=${NEBULA_MODEL} \ + --nebula_model=${NEBULA_MODEL} \ + --env=${ENVS} \ + --force \ + " +if [ -n "${OPENLM_TOKEN}" ]; then + mdl_args="${mdl_args} --env=OPENLM_TOKEN=${OPENLM_TOKEN}" +fi + +echo ${args} +echo ${mdl_args} + +nebulactl run mdl --user_params="${args}" $mdl_args diff --git a/examples/qwen2.5-7B-rlvr_megatron/submit_pipeline_amd_zero3_lora.sh b/examples/qwen2.5-7B-rlvr_megatron/submit_pipeline_amd_zero3_lora.sh new file mode 100644 index 000000000..25016bfa5 --- /dev/null +++ b/examples/qwen2.5-7B-rlvr_megatron/submit_pipeline_amd_zero3_lora.sh @@ -0,0 +1,46 @@ +#!/bin/bash +set +x +source "examples/scripts/config.sh" + +WORKER_COUNT=2 +CONFIG_FILE="rlvr_lora_zero3_amd.yaml" +# 替换为mos uri +NEBULA_MODEL="" +ENTRY_FILE="examples/start_rlvr_pipeline.py" + +CONFIG_PATH=$(basename $(dirname $0)) +CONFIG_NAME="${CONFIG_FILE%.yaml}" +JOB_NAME="$CONFIG_PATH-$CONFIG_NAME" + +QUEUE="nebula_test2_308x_gpu_hang" + + +echo "JOB_NAME: ${JOB_NAME}" +echo "WORKER_COUNT: ${WORKER_COUNT}" +echo "CONFIG_NAME: ${CONFIG_NAME}" +echo "CONFIG_PATH: ${CONFIG_PATH}" +echo "ENTRY_FILE: ${ENTRY_FILE}" + +args="--config_name ${CONFIG_NAME} --config_path ${CONFIG_PATH}" + + +mdl_args="--queue=${QUEUE} \ + --entry=${ENTRY_FILE} \ + --worker_count=${WORKER_COUNT} \ + --file.cluster_file=examples/scripts/cluster.json \ + --job_name=${JOB_NAME} \ + --algo_name=pytorch260_rocm700rc4 \ + 
--requirements_file_name=nebula_patch/requirements/requirements_torch260_vllm_amd.txt \ + --oss_appendable=true \ + --_NEBULA_MODEL=${NEBULA_MODEL} \ + --nebula_model=${NEBULA_MODEL} \ + --force \ + " +if [ -n "${OPENLM_TOKEN}" ]; then + mdl_args="${mdl_args} --env=OPENLM_TOKEN=${OPENLM_TOKEN}" +fi + +echo ${args} +echo ${mdl_args} + +nebulactl run mdl --user_params="${args}" $mdl_args diff --git a/examples/qwen2.5-vl-7B-math/submit_pipeline_amd.sh b/examples/qwen2.5-vl-7B-math/submit_pipeline_amd.sh new file mode 100644 index 000000000..2f5102eb3 --- /dev/null +++ b/examples/qwen2.5-vl-7B-math/submit_pipeline_amd.sh @@ -0,0 +1,49 @@ +#!/bin/bash +set +x +source "examples/scripts/config.sh" + +WORKER_COUNT=2 +CONFIG_FILE="rlvr_math_megatron_amd.yaml" +# 替换为mos uri +NEBULA_MODEL="" +ENTRY_FILE="examples/start_rlvr_vlmath_pipeline.py" + +CONFIG_PATH=$(basename $(dirname $0)) +CONFIG_NAME="${CONFIG_FILE%.yaml}" +JOB_NAME="$CONFIG_PATH-$CONFIG_NAME" + + +QUEUE="nebula_test2_308x_gpu_hang" +# QUEUE="nebula_test_308x" +ENVS="NCCL_PF_UCM_TIMEOUT=600000,NCCL_SOCKET_IFNAME=bond0,MIOPEN_DEBUG_FORCE_IMMED_MODE_FALLBACK=1" +# ENVS="NCCL_PF_UCM_TIMEOUT=600000" + +echo "JOB_NAME: ${JOB_NAME}" +echo "WORKER_COUNT: ${WORKER_COUNT}" +echo "CONFIG_NAME: ${CONFIG_NAME}" +echo "CONFIG_PATH: ${CONFIG_PATH}" +echo "ENTRY_FILE: ${ENTRY_FILE}" + +args="--config_name ${CONFIG_NAME} --config_path ${CONFIG_PATH}" + +mdl_args="--queue=${QUEUE} \ + --entry=${ENTRY_FILE} \ + --worker_count=${WORKER_COUNT} \ + --file.cluster_file=examples/scripts/cluster.json \ + --job_name=${JOB_NAME} \ + --algo_name=pytorch280 \ + --requirements_file_name=nebula_patch/requirements/requirements_torch280_vllm_amd.txt \ + --oss_appendable=true \ + --_NEBULA_MODEL=${NEBULA_MODEL} \ + --nebula_model=${NEBULA_MODEL} \ + --env=${ENVS} \ + --force \ + " +if [ -n "${OPENLM_TOKEN}" ]; then + mdl_args="${mdl_args} --env=OPENLM_TOKEN=${OPENLM_TOKEN}" +fi + +echo ${args} +echo ${mdl_args} + +nebulactl run mdl --user_params="${args}" $mdl_args diff --git a/examples/qwen3-235BA22B-rlvr_megatron/rlvr_config_amd.yaml b/examples/qwen3-235BA22B-rlvr_megatron/rlvr_config_amd.yaml new file mode 100644 index 000000000..7e81a8a11 --- /dev/null +++ b/examples/qwen3-235BA22B-rlvr_megatron/rlvr_config_amd.yaml @@ -0,0 +1,267 @@ +hydra: + run: + dir: . 
+ output_subdir: null + +exp_name: "qwen3-235BA22B-rlvr-config_amd" +seed: 42 +logging_dir: ./output/logs +output_dir: ./output +system_envs: + USE_MODELSCOPE: '1' + +checkpoint_config: + type: file_system + output_dir: ./rl_examples/models/${exp_name} + +track_with: tensorboard +tracker_kwargs: + log_dir: ./rl_examples/llm/tensorboard/roll_exp/rlvr + +num_gpus_per_node: 8 + +max_steps: 500 +save_steps: 100 +logging_steps: 1 +eval_steps: 10 +resume_from_checkpoint: false + + +rollout_batch_size: 64 # prompt +prompt_length: 2048 +response_length: 4096 + +num_return_sequences_in_group: 8 +ppo_epochs: 1 +adv_estimator: "reinforce" + +# clip +value_clip: 0.5 +reward_clip: 10 +advantage_clip: 2.0 +dual_clip_loss: true + +# normalize +norm_mean_type: ~ +norm_std_type: ~ + +# data mask +max_len_mask: true +difficulty_mask: true +difficulty_low_threshold: 0.1 +difficulty_high_threshold: 0.95 +error_max_len_clip: false + +# data weight +difficulty_loss_weight: false +length_loss_weight: false + +# reward +add_token_level_kl: false + +# advantage +whiten_advantages: true + +# dynamic sampling scheduler +# use_additional_prompts: true +# max_running_requests: 256 +# is_num_return_sequences_expand: false + +pretrain: Qwen/Qwen3-235B-A22B +reward_pretrain: Qwen/Qwen3-235B-A22B + +validation: + data_args: + template: qwen3 + file_name: + - data/math_benchmarks.jsonl + generating_args: + top_p: 0.6 + top_k: 50 + num_beams: 1 + temperature: 0.6 + num_return_sequences: 1 + eval_steps: 10 + +actor_train: + model_args: + disable_gradient_checkpointing: false + dtype: bf16 + model_type: ~ + training_args: + learning_rate: 1.0e-6 + weight_decay: 0 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 64 + warmup_steps: 20 + num_train_epochs: 50 + data_args: + template: qwen3 + file_name: + - data/code_KodCode_data.jsonl + # - data/llm_judge_Multi-subject-RLVR_deal_new.jsonl + - data/math_deepmath_deal.jsonl + - data/general_ifeval_train_deal.jsonl + - data/general_CrossThink-QA_deal.jsonl + domain_interleave_probs: + math_rule: 0.4 + code_sandbox: 0.3 + # llm_judge: 0.1 + crossthinkqa: 0.1 + ifeval: 0.1 + dataset_dir: data + messages: messages + interleave_probs: "1.0" + preprocessing_num_workers: 16 + strategy_args: + strategy_name: megatron_train + strategy_config: + tensor_model_parallel_size: 4 + pipeline_model_parallel_size: 8 + virtual_pipeline_model_parallel_size: 6 + expert_model_parallel_size: 8 + context_parallel_size: 1 + account_for_loss_in_pipeline_split: true + account_for_embedding_in_pipeline_split: true + use_distributed_optimizer: true + sequence_parallel: true + overlap_grad_reduce: true + bias_activation_fusion: true + apply_rope_fusion: true + moe_grouped_gemm: true + moe_layer_recompute: true + moe_token_dispatcher_type: "alltoall" + device_mapping: list(range(0,256)) + infer_batch_size: 2 + +actor_infer: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + generating_args: + max_new_tokens: ${response_length} + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: ${num_return_sequences_in_group} + data_args: + template: qwen3 + strategy_args: + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.6 + load_format: dummy + tensor_parallel_size: 8 + enforce_eager: true + num_gpus_per_worker: 8 + device_mapping: list(range(0,200)) # device share with llm reward + infer_batch_size: 1 + +reference: + model_args: + dtype: bf16 + model_type: ~ + data_args: + template: qwen3 + strategy_args: + strategy_name: megatron_infer + 
strategy_config: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 8 + virtual_pipeline_model_parallel_size: 6 + expert_model_parallel_size: 8 + account_for_loss_in_pipeline_split: true + account_for_embedding_in_pipeline_split: true + use_distributed_optimizer: true + sequence_parallel: true + bias_activation_fusion: true + apply_rope_fusion: true + moe_grouped_gemm: true + moe_token_dispatcher_type: "alltoall" + device_mapping: list(range(0,256)) + infer_batch_size: 2 + +rewards: + crossthinkqa: + worker_cls: roll.pipeline.rlvr.rewards.crossthinkqa_rule_reward_worker.CrossThinkQARuleRewardWorker + reward_type: soft + response_length_penalty_coef: 0.0 + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen3 + tag_included: [crossthinkqa] + world_size: 8 + infer_batch_size: 4 + ifeval: + worker_cls: roll.pipeline.rlvr.rewards.ifeval_rule_reward_worker.GeneralRuleRewardWorker + reward_type: soft + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen3 + tag_included: [ifeval] + world_size: 8 + infer_batch_size: 4 + math_rule: + worker_cls: roll.pipeline.rlvr.rewards.math_rule_reward_worker.MathRuleRewardWorker + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen3 + tag_included: [deepmath_103k, aime] + world_size: 8 + infer_batch_size: 1 +# dynamic filter config +# query_filter_config: +# type: mean_filter +# filter_args: +# threshold_up: 0.9 +# threshold_down: 0.1 + code_sandbox: + use_local: true + worker_cls: roll.pipeline.rlvr.rewards.code_sandbox_reward_worker.CodeSandboxRewardWorker + tag_included: [KodCode] + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen3 + world_size: 8 + infer_batch_size: 1 +# query_filter_config: +# type: std_filter +# filter_args: +# std_threshold: 0 + llm_judge: + # NOTE: llm as judge 也需要gpu, 不能和actor infer共享gpu + worker_cls: roll.pipeline.rlvr.rewards.llm_judge_reward_worker.LLMJudgeRewardWorker + judge_prompt: Qwen2.5-7B-Instruct-RLVR-prompt + judge_model_type: inference + tag_included: [RLVR] + model_args: + model_name_or_path: virtuoussy/Qwen2.5-7B-Instruct-RLVR + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: trl + generating_args: + max_new_tokens: 100 + top_p: 0.8 + top_k: 50 + num_beams: 1 + temperature: 0.8 + num_return_sequences: 1 + data_args: + template: qwen3 + strategy_args: + # strategy_name: hf_infer + # strategy_config: null + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.75 + block_size: 16 + max_model_len: 8000 + load_format: auto + device_mapping: list(range(200,256)) + infer_batch_size: 4 \ No newline at end of file diff --git a/examples/qwen3-235BA22B-rlvr_megatron/submit_pipeline_amd.sh b/examples/qwen3-235BA22B-rlvr_megatron/submit_pipeline_amd.sh new file mode 100644 index 000000000..bf10ec75a --- /dev/null +++ b/examples/qwen3-235BA22B-rlvr_megatron/submit_pipeline_amd.sh @@ -0,0 +1,49 @@ +#!/bin/bash +set +x +source "examples/scripts/config.sh" + +WORKER_COUNT=32 +CONFIG_FILE="rlvr_config_amd.yaml" +# 替换为mos uri +NEBULA_MODEL="" +ENTRY_FILE="examples/start_rlvr_pipeline.py" + +CONFIG_PATH=$(basename $(dirname $0)) +CONFIG_NAME="${CONFIG_FILE%.yaml}" +JOB_NAME="$CONFIG_PATH-$CONFIG_NAME" + + +QUEUE="nebula_test2_308x_gpu_hang" +# QUEUE="nebula_test_308x" +ENVS="NCCL_PF_UCM_TIMEOUT=600000,NCCL_SOCKET_IFNAME=bond0,NCCL_DEBUG=INFO" +# ENVS="NCCL_PF_UCM_TIMEOUT=600000" + +echo "JOB_NAME: ${JOB_NAME}" +echo "WORKER_COUNT: ${WORKER_COUNT}" 
+echo "CONFIG_NAME: ${CONFIG_NAME}" +echo "CONFIG_PATH: ${CONFIG_PATH}" +echo "ENTRY_FILE: ${ENTRY_FILE}" + +args="--config_name ${CONFIG_NAME} --config_path ${CONFIG_PATH}" + +mdl_args="--queue=${QUEUE} \ + --entry=${ENTRY_FILE} \ + --worker_count=${WORKER_COUNT} \ + --file.cluster_file=examples/scripts/cluster.json \ + --job_name=${JOB_NAME} \ + --algo_name=pytorch280 \ + --requirements_file_name=nebula_patch/requirements/requirements_torch280_vllm_amd.txt \ + --oss_appendable=true \ + --_NEBULA_MODEL=${NEBULA_MODEL} \ + --nebula_model=${NEBULA_MODEL} \ + --env=${ENVS} \ + --force \ + " +if [ -n "${OPENLM_TOKEN}" ]; then + mdl_args="${mdl_args} --env=OPENLM_TOKEN=${OPENLM_TOKEN}" +fi + +echo ${args} +echo ${mdl_args} + +nebulactl run mdl --user_params="${args}" $mdl_args diff --git a/examples/qwen3-30BA3B-rlvr_megatron/rlvr_config_amd.yaml b/examples/qwen3-30BA3B-rlvr_megatron/rlvr_config_amd.yaml index 5c2cb1cb1..4af2f0893 100644 --- a/examples/qwen3-30BA3B-rlvr_megatron/rlvr_config_amd.yaml +++ b/examples/qwen3-30BA3B-rlvr_megatron/rlvr_config_amd.yaml @@ -253,7 +253,13 @@ rewards: data_args: template: qwen2_5 strategy_args: - strategy_name: hf_infer - strategy_config: null + # strategy_name: hf_infer + # strategy_config: null + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.8 + block_size: 16 + max_model_len: 8000 + load_format: auto device_mapping: list(range(24,32)) infer_batch_size: 4 \ No newline at end of file diff --git a/examples/qwen3-30BA3B-rlvr_megatron/submit_pipeline_amd.sh b/examples/qwen3-30BA3B-rlvr_megatron/submit_pipeline_amd.sh new file mode 100644 index 000000000..f2937e32e --- /dev/null +++ b/examples/qwen3-30BA3B-rlvr_megatron/submit_pipeline_amd.sh @@ -0,0 +1,49 @@ +#!/bin/bash +set +x +source "examples/scripts/config.sh" + +WORKER_COUNT=4 +CONFIG_FILE="rlvr_config_amd.yaml" +# 替换为mos uri +NEBULA_MODEL="" +ENTRY_FILE="examples/start_rlvr_pipeline.py" + +CONFIG_PATH=$(basename $(dirname $0)) +CONFIG_NAME="${CONFIG_FILE%.yaml}" +JOB_NAME="$CONFIG_PATH-$CONFIG_NAME" + + +QUEUE="nebula_test2_308x_gpu_hang" +# QUEUE="nebula_test_308x" +ENVS="NCCL_PF_UCM_TIMEOUT=600000,NCCL_SOCKET_IFNAME=bond0" +# ENVS="NCCL_PF_UCM_TIMEOUT=600000" + +echo "JOB_NAME: ${JOB_NAME}" +echo "WORKER_COUNT: ${WORKER_COUNT}" +echo "CONFIG_NAME: ${CONFIG_NAME}" +echo "CONFIG_PATH: ${CONFIG_PATH}" +echo "ENTRY_FILE: ${ENTRY_FILE}" + +args="--config_name ${CONFIG_NAME} --config_path ${CONFIG_PATH}" + +mdl_args="--queue=${QUEUE} \ + --entry=${ENTRY_FILE} \ + --worker_count=${WORKER_COUNT} \ + --file.cluster_file=examples/scripts/cluster.json \ + --job_name=${JOB_NAME} \ + --algo_name=pytorch280 \ + --requirements_file_name=nebula_patch/requirements/requirements_torch280_vllm_amd.txt \ + --oss_appendable=true \ + --_NEBULA_MODEL=${NEBULA_MODEL} \ + --nebula_model=${NEBULA_MODEL} \ + --env=${ENVS} \ + --force \ + " +if [ -n "${OPENLM_TOKEN}" ]; then + mdl_args="${mdl_args} --env=OPENLM_TOKEN=${OPENLM_TOKEN}" +fi + +echo ${args} +echo ${mdl_args} + +nebulactl run mdl --user_params="${args}" $mdl_args diff --git a/examples/qwen3-next-80BA3B-rlvr_megatron/rlvr_config_amd.yaml b/examples/qwen3-next-80BA3B-rlvr_megatron/rlvr_config_amd.yaml new file mode 100644 index 000000000..db646a928 --- /dev/null +++ b/examples/qwen3-next-80BA3B-rlvr_megatron/rlvr_config_amd.yaml @@ -0,0 +1,196 @@ +hydra: + run: + dir: . 
+ output_subdir: null + +exp_name: "qwen3-next-80BA3B-rlvr-config" +seed: 42 +logging_dir: ./output/logs +output_dir: ./output +system_envs: + USE_MODELSCOPE: '1' + +checkpoint_config: + type: file_system + output_dir: ./rl_examples/models/${exp_name} + +#track_with: wandb +#tracker_kwargs: +# api_key: +# project: roll_examples +# notes: roll_examples +# tags: +# - rlvr +# - baseline + +track_with: tensorboard +tracker_kwargs: + log_dir: ./roll_exp/rlvr/${exp_name}/ + +num_gpus_per_node: 8 + +max_steps: 500 +save_steps: 100 +logging_steps: 1 +eval_steps: 10 +resume_from_checkpoint: false + + +rollout_batch_size: 64 # prompt +prompt_length: 2048 +response_length: 6144 + +num_return_sequences_in_group: 8 +ppo_epochs: 1 +adv_estimator: "reinforce" + +# clip +value_clip: 0.5 +reward_clip: 10 +advantage_clip: 2.0 +dual_clip_loss: true + +# normalize +reward_norm: null +reward_shift: false +reward_scale: false + +# data mask +max_len_mask: true +difficulty_mask: true +difficulty_low_threshold: 0.1 +difficulty_high_threshold: 0.95 +error_max_len_clip: false + +# data weight +difficulty_loss_weight: false +length_loss_weight: false + +# reward +add_token_level_kl: false + +# advantage +whiten_advantages: true + +# dynamic sampling scheduler +# use_additional_prompts: true +# max_running_requests: 256 +# is_num_return_sequences_expand: false + +pretrain: Qwen/Qwen3-Next-80B-A3B-Instruct +reward_pretrain: Qwen/Qwen3-Next-80B-A3B-Instruct + +# validation: +# data_args: +# template: qwen2_5 +# file_name: +# - data/aime24_25_deal.jsonl +# generating_args: +# top_p: 0.6 +# top_k: 50 +# num_beams: 1 +# temperature: 0.6 +# num_return_sequences: 1 +# eval_steps: 10 + +actor_train: + model_args: + disable_gradient_checkpointing: false + dtype: bf16 + model_type: ~ + training_args: + learning_rate: 1.0e-6 + weight_decay: 0 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 32 + warmup_steps: 1 + num_train_epochs: 5 + data_args: + template: native + file_name: + - data/math_deepmath_deal.jsonl + domain_interleave_probs: + math_rule: 1.0 + dataset_dir: data + messages: messages + interleave_probs: "1.0" + preprocessing_num_workers: 16 + strategy_args: + strategy_name: megatron_train + strategy_config: + tensor_model_parallel_size: 1 + expert_model_parallel_size: 8 + pipeline_model_parallel_size: 4 + virtual_pipeline_model_parallel_size: 12 + context_parallel_size: 1 + use_distributed_optimizer: true + # account_for_loss_in_pipeline_split: true + moe_token_dispatcher_type: alltoall + recompute_granularity: selective + recompute_modules: "moe" + bias_activation_fusion: true + moe_grouped_gemm: true + moe_shared_expert_overlap: true + bf16: true + additional_configs: + moe_permute_fusion: true + device_mapping: list(range(0,64)) + infer_batch_size: 1 + +actor_infer: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + generating_args: + max_new_tokens: ${response_length} + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: ${num_return_sequences_in_group} + data_args: + template: native + strategy_args: + strategy_name: vllm + strategy_config: + tensor_parallel_size: 4 + gpu_memory_utilization: 0.6 + block_size: 16 + max_model_len: 8192 + enforce_eager: true + device_mapping: list(range(0,64)) + infer_batch_size: 1 + +reference: + model_args: + dtype: bf16 + model_type: ~ + data_args: + template: native + strategy_args: + strategy_name: megatron_infer + strategy_config: + tensor_model_parallel_size: 1 + expert_model_parallel_size: 8 + 
pipeline_model_parallel_size: 2 + virtual_pipeline_model_parallel_size: 12 + use_distributed_optimizer: true + moe_token_dispatcher_type: alltoall + bias_activation_fusion: true + moe_grouped_gemm: true + moe_shared_expert_overlap: true + additional_configs: + moe_permute_fusion: true + device_mapping: list(range(0,64)) + infer_batch_size: 1 + +rewards: + math_rule: + worker_cls: roll.pipeline.rlvr.rewards.math_rule_reward_worker.MathRuleRewardWorker + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: native + tag_included: [deepmath_103k, aime] + world_size: 8 + infer_batch_size: 1 diff --git a/examples/qwen3-next-80BA3B-rlvr_megatron/submit_pipeline_amd.sh b/examples/qwen3-next-80BA3B-rlvr_megatron/submit_pipeline_amd.sh new file mode 100644 index 000000000..8d10e0037 --- /dev/null +++ b/examples/qwen3-next-80BA3B-rlvr_megatron/submit_pipeline_amd.sh @@ -0,0 +1,49 @@ +#!/bin/bash +set +x +source "examples/scripts/config.sh" + +WORKER_COUNT=8 +CONFIG_FILE="rlvr_config_amd.yaml" +# 替换为mos uri +NEBULA_MODEL="" +ENTRY_FILE="examples/start_rlvr_pipeline.py" + +CONFIG_PATH=$(basename $(dirname $0)) +CONFIG_NAME="${CONFIG_FILE%.yaml}" +JOB_NAME="$CONFIG_PATH-$CONFIG_NAME" + + +QUEUE="nebula_test2_308x_gpu_hang" +# QUEUE="nebula_test_308x" +ENVS="NCCL_PF_UCM_TIMEOUT=600000,NCCL_SOCKET_IFNAME=bond0" +# ENVS="NCCL_PF_UCM_TIMEOUT=600000" + +echo "JOB_NAME: ${JOB_NAME}" +echo "WORKER_COUNT: ${WORKER_COUNT}" +echo "CONFIG_NAME: ${CONFIG_NAME}" +echo "CONFIG_PATH: ${CONFIG_PATH}" +echo "ENTRY_FILE: ${ENTRY_FILE}" + +args="--config_name ${CONFIG_NAME} --config_path ${CONFIG_PATH}" + +mdl_args="--queue=${QUEUE} \ + --entry=${ENTRY_FILE} \ + --worker_count=${WORKER_COUNT} \ + --file.cluster_file=examples/scripts/cluster.json \ + --job_name=${JOB_NAME} \ + --algo_name=pytorch280 \ + --requirements_file_name=nebula_patch/requirements/requirements_torch280_vllm_amd.txt \ + --oss_appendable=true \ + --_NEBULA_MODEL=${NEBULA_MODEL} \ + --nebula_model=${NEBULA_MODEL} \ + --env=${ENVS} \ + --force \ + " +if [ -n "${OPENLM_TOKEN}" ]; then + mdl_args="${mdl_args} --env=OPENLM_TOKEN=${OPENLM_TOKEN}" +fi + +echo ${args} +echo ${mdl_args} + +nebulactl run mdl --user_params="${args}" $mdl_args diff --git a/mcore_adapter/src/mcore_adapter/platforms/rocm.py b/mcore_adapter/src/mcore_adapter/platforms/rocm.py index 6a04c06ed..0df5fdfa6 100644 --- a/mcore_adapter/src/mcore_adapter/platforms/rocm.py +++ b/mcore_adapter/src/mcore_adapter/platforms/rocm.py @@ -34,14 +34,17 @@ def get_custom_env_vars(cls) -> dict: "VLLM_ALLOW_INSECURE_SERIALIZATION": "1", # These VLLM related enviroment variables are related to backend. maybe used afterwards. 
# "VLLM_USE_TRITON_FLASH_ATTN":"0", - # "VLLM_ROCM_USE_AITER":"1", + "VLLM_ROCM_USE_AITER":"1", # "VLLM_ROCM_USE_AITER_MOE":"1", # "VLLM_ROCM_USE_AITER_ASMMOE":"1", # "VLLM_ROCM_USE_AITER_PAGED_ATTN":"1", # "RAY_DEBUG": "legacy", - "VLLM_USE_V1": "0", + "VLLM_USE_V1": "1", "TORCHINDUCTOR_COMPILE_THREADS": "2", "PYTORCH_HIP_ALLOC_CONF": "expandable_segments:True", + "SAFETENSORS_FAST_GPU":"1", + "VLLM_ROCM_USE_AITER_MHA":"0", + "VLLM_ALLOW_LONG_MAX_MODEL_LEN":"1", # "NCCL_DEBUG_SUBSYS":"INIT,COLL", # "NCCL_DEBUG":"INFO", # "NCCL_DEBUG_FILE":"rccl.%h.%p.log", diff --git a/roll/distributed/strategy/vllm_strategy.py b/roll/distributed/strategy/vllm_strategy.py index f44dd29f8..00ecacb10 100644 --- a/roll/distributed/strategy/vllm_strategy.py +++ b/roll/distributed/strategy/vllm_strategy.py @@ -27,6 +27,13 @@ from roll.utils.logging import get_logger from roll.utils.offload_states import OffloadStateType from roll.platforms import current_platform +try: + from vllm.inputs import TokensPrompt + high_version_vllm=True +except: + high_version_vllm=False + pass + logger = get_logger() @@ -146,9 +153,16 @@ def _generate_standard(self, batch: DataProto, generation_config) -> torch.Tenso if "multi_modal_data" in batch.non_tensor_batch: vllm_input_args["prompts"] = batch.non_tensor_batch["multi_modal_data"] else: - vllm_input_args["prompt_token_ids"] = gather_unpadded_input_ids( - input_ids=input_ids, attention_mask=attention_mask - ) + if high_version_vllm: + prompt_token_ids_list=gather_unpadded_input_ids( + input_ids=input_ids, attention_mask=attention_mask + ) + + vllm_input_args["prompts"] = [TokensPrompt(prompt_token_ids=prompt_token_ids)for prompt_token_ids in prompt_token_ids_list] + else: + vllm_input_args["prompt_token_ids"] = gather_unpadded_input_ids( + input_ids=input_ids, attention_mask=attention_mask + ) lora_requests = None if self.is_lora: diff --git a/roll/models/model_providers.py b/roll/models/model_providers.py index 2e8eb33e7..54ecfc6db 100644 --- a/roll/models/model_providers.py +++ b/roll/models/model_providers.py @@ -278,7 +278,7 @@ def forward_patch( use_cache, output_attentions, output_hidden_states, - return_dict, + # return_dict, pixel_values, pixel_values_videos, image_grid_thw, diff --git a/roll/platforms/rocm.py b/roll/platforms/rocm.py index 229719b4b..c55b59a84 100644 --- a/roll/platforms/rocm.py +++ b/roll/platforms/rocm.py @@ -40,9 +40,12 @@ def get_custom_env_vars(cls) -> dict: # "VLLM_ROCM_USE_AITER_ASMMOE":"1", # "VLLM_ROCM_USE_AITER_PAGED_ATTN":"1", # "RAY_DEBUG": "legacy", - "VLLM_USE_V1": "0", + "VLLM_USE_V1": "1", "TORCHINDUCTOR_COMPILE_THREADS": "2", "PYTORCH_HIP_ALLOC_CONF": "expandable_segments:True", + "SAFETENSORS_FAST_GPU":"1", + "VLLM_ROCM_USE_AITER_MHA":"0", + "VLLM_ALLOW_LONG_MAX_MODEL_LEN":"1", # "NCCL_DEBUG_SUBSYS":"INIT,COLL", # "NCCL_DEBUG":"INFO", # "NCCL_DEBUG_FILE":"rccl.%h.%p.log", From 41fe2740ad8b75fa61203cef70cc5eef0fbb2436 Mon Sep 17 00:00:00 2001 From: gs450068 Date: Fri, 28 Nov 2025 15:26:10 +0800 Subject: [PATCH 14/58] fix: fix tokenizer usage in llm judge reward worker. 
--- roll/pipeline/rlvr/rewards/llm_judge_reward_worker.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/roll/pipeline/rlvr/rewards/llm_judge_reward_worker.py b/roll/pipeline/rlvr/rewards/llm_judge_reward_worker.py index e3066c69b..955aa9419 100644 --- a/roll/pipeline/rlvr/rewards/llm_judge_reward_worker.py +++ b/roll/pipeline/rlvr/rewards/llm_judge_reward_worker.py @@ -50,6 +50,7 @@ def __init__(self, worker_config: WorkerConfig): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def initialize(self, pipeline_config): super().initialize(pipeline_config) + self.actor_tokenizer = default_tokenizer_provider(pipeline_config.actor_train.model_args) if self.judge_model_type == "api": self.tokenizer = default_tokenizer_provider(model_args=self.worker_config.model_args) print(f"{self.worker_name} initialized with API model") @@ -220,8 +221,8 @@ def compute_rewards(self, data: DataProto): return self._compute_rewards_impl(data, metrics) def _compute_rewards_impl(self, data: DataProto, metrics: Dict): - prompts_text_list = self.tokenizer.batch_decode(data.batch["prompts"], skip_special_tokens=True) - response_text_list = self.tokenizer.batch_decode(data.batch["responses"], skip_special_tokens=True) + prompts_text_list = self.actor_tokenizer.batch_decode(data.batch["prompts"], skip_special_tokens=True) + response_text_list = self.actor_tokenizer.batch_decode(data.batch["responses"], skip_special_tokens=True) scores = [] for prompt_id, prompt_txt, response, reference in zip( From 38bfc2e7ed61acc34460a7eeb63191ea41e0252b Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Fri, 5 Dec 2025 17:09:04 +0800 Subject: [PATCH 15/58] (feat): add vlm option. --- roll/datasets/collator.py | 2 +- roll/distributed/scheduler/generate_scheduler.py | 5 +++++ roll/pipeline/rlvr/rlvr_vlm_pipeline.py | 4 ++-- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/roll/datasets/collator.py b/roll/datasets/collator.py index bcaddc545..8eba22ac1 100644 --- a/roll/datasets/collator.py +++ b/roll/datasets/collator.py @@ -143,7 +143,7 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: else None, text=feature[self.prompt_key], ) - for key in ["prompt"]: # remove non-tensor feature, e.g. 
tbstars2_moe_vista has prompt in processor output + for key in ["prompt", "position_ids", "rope_deltas"]: # remove unnecessary feature if key in model_inputs: model_inputs.pop(key) for key in filter(lambda k: k in model_inputs, self.padded_keys): diff --git a/roll/distributed/scheduler/generate_scheduler.py b/roll/distributed/scheduler/generate_scheduler.py index 8b6c5c06b..617539984 100644 --- a/roll/distributed/scheduler/generate_scheduler.py +++ b/roll/distributed/scheduler/generate_scheduler.py @@ -386,6 +386,7 @@ def __init__(self, pipeline_config=None): self.collect_fn_kwargs = None self.collect_fn = None self.tokenizer = None + self.processor = None self.response_filter_fn = None self.query_filter_fn = None self.response_callback_fn = None @@ -413,6 +414,7 @@ def set_scheduler( response_callback_fn=None, state: Dict[str, Any] = None, is_val: bool = False, + is_vlm: bool = False, ): """ GenerateScheduler可以由多个实例,不再局限于单例 @@ -439,6 +441,9 @@ def set_scheduler( self.collect_fn_cls = collect_fn_cls self.collect_fn_kwargs = collect_fn_kwargs self.tokenizer = default_tokenizer_provider(model_args=self.actor_cluster.worker_config.model_args) + self.processor = default_processor_provider(model_args=self.actor_cluster.worker_config.model_args) + if is_vlm: + collect_fn_kwargs["processor"] = self.processor self.collect_fn = self.collect_fn_cls(tokenizer=self.tokenizer, **self.collect_fn_kwargs) if self.is_use_additional_prompts: diff --git a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py index eb689c1c8..0405e2b29 100644 --- a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py @@ -367,7 +367,6 @@ def __init__(self, pipeline_config: RLVRConfig): collect_fn_kwargs=dict( # tokenizer passed by DynamicSamplingScheduler.set_scheduler # tokenizer=self.tokenizer, - processor=self.processor, extra_unpadded_keys=["domain", "reward_model"], extra_data_provider=get_extra_data_provider( self.pipeline_config.actor_train.model_args.model_name_or_path, processor=self.processor @@ -383,6 +382,7 @@ def __init__(self, pipeline_config: RLVRConfig): query_filter_fn=query_filter_fn, response_callback_fn=generate_scheduler.report_response.remote, state=self.state.kv.get(f"scheduler_state_{domain}", None), + is_vlm=True, ) ) self.generate_schedulers[domain] = generate_scheduler @@ -411,7 +411,6 @@ def __init__(self, pipeline_config: RLVRConfig): collect_fn_kwargs=dict( # tokenizer passed by DynamicSamplingScheduler.set_scheduler # tokenizer=self.tokenizer, - processor=self.processor, # val metrics are grouped by tag rather than domain extra_unpadded_keys=["domain", "reward_model", "tag"], extra_data_provider=get_extra_data_provider( @@ -427,6 +426,7 @@ def __init__(self, pipeline_config: RLVRConfig): response_filter_fn=lambda data_item, config: True, query_filter_fn=lambda data_list, config: True, response_callback_fn=self.val_generate_scheduler.report_response.remote, + is_vlm=True, ) ) From 8e4bf7c34528b11ff37617423879ba200ff1fecc Mon Sep 17 00:00:00 2001 From: "weixun.wwx" Date: Thu, 30 Oct 2025 11:14:47 +0800 Subject: [PATCH 16/58] (feat): agentic-spec actor worker. 
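The agentic pipeline gets its own ActorWorker whose loss_func builds the clipped PPO surrogate (with optional dual clip, KL loss and entropy bonus) on top of the strategy's log-prob op. A self-contained sketch of the core objective with illustrative tensors and clip constants (the real worker reads pg_clip / pg_clip_low / pg_clip_high, dual_clip_loss and loss_agg_mode from pipeline_config):

```python
import torch

log_probs     = torch.tensor([[-1.0, -0.8, -1.2]])   # current policy
old_log_probs = torch.tensor([[-1.1, -0.7, -1.0]])   # policy that generated the rollout
advantages    = torch.tensor([[ 0.5, -0.3,  1.0]])
response_mask = torch.tensor([[ 1.0,  1.0,  0.0]])   # last position is padding

pg_clip, dual_clip_c = 0.2, 3.0                      # illustrative clip constants
ratio = (log_probs - old_log_probs).exp()

surr1 = ratio * advantages
surr2 = ratio.clamp(1 - pg_clip, 1 + pg_clip) * advantages
pg_loss = -torch.min(surr1, surr2)                   # standard PPO clipping

# dual clip: for negative advantages, cap the loss at -c * A (c > 1)
dual = -torch.max(-pg_loss, dual_clip_c * advantages)
pg_loss = torch.where(advantages < 0, dual, pg_loss)

# token-mean aggregation over valid response tokens
loss = (pg_loss * response_mask).sum() / response_mask.sum()
print(loss)
```

Dual clip only activates for negative advantages, where it bounds how large the clipped loss can grow when the ratio explodes.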
--- roll/pipeline/agentic/agentic_actor_worker.py | 84 +++++++++++++++++++ roll/pipeline/agentic/agentic_config.py | 2 +- 2 files changed, 85 insertions(+), 1 deletion(-) create mode 100644 roll/pipeline/agentic/agentic_actor_worker.py diff --git a/roll/pipeline/agentic/agentic_actor_worker.py b/roll/pipeline/agentic/agentic_actor_worker.py new file mode 100644 index 000000000..f748096bc --- /dev/null +++ b/roll/pipeline/agentic/agentic_actor_worker.py @@ -0,0 +1,84 @@ +import numpy as np +import torch + +from roll.distributed.scheduler.protocol import DataProto +from roll.pipeline.base_worker import ActorWorker as BaseActorWorker +from roll.utils.functionals import masked_mean, agg_loss, compute_approx_kl + + +class ActorWorker(BaseActorWorker): + def loss_func(self, data: DataProto, output_tensor: torch.Tensor): + """ + loss func接口定义: + data: DataProto, 由train_step透传 + output_tensor: torch.Tensor, model.forward()的输出Tensor + """ + response_mask = data.batch["response_mask"][:, 1:].long() + ref_log_probs = data.batch["ref_log_probs"] + old_log_probs = data.batch["old_log_probs"] + advantages = data.batch["advantages"] + + log_probs = self.strategy.op_compute_log_probs( + logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"] + ) + + ratio = (log_probs - old_log_probs).exp() + + pg_clip_low = self.pipeline_config.pg_clip_low if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip + pg_clip_high = self.pipeline_config.pg_clip_high if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip + surr1 = ratio * advantages + surr2 = ratio.clamp(1 - pg_clip_low, 1 + pg_clip_high) * advantages + pg_loss = -torch.min(surr1, surr2) + if self.pipeline_config.dual_clip_loss: + dual_clip_loss = -torch.max(-pg_loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) + pg_loss = torch.where(advantages < 0, dual_clip_loss, pg_loss) + + pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + + kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=response_mask, + kl_penalty="k3") + kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + + approxkl = compute_approx_kl( + log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="mse" + ) + policykl = compute_approx_kl( + log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="kl" + ) + clipped_low = (ratio < 1 - pg_clip_low).float() + clipped_high = (ratio > 1 + pg_clip_high).float() + clipped = (clipped_low + clipped_high).float() + + if self.pipeline_config.use_kl_loss: + total_loss = pg_loss + kl_loss * self.pipeline_config.kl_loss_coef + else: + total_loss = pg_loss + if self.pipeline_config.entropy_loss_coef > 0: + entropy = self.strategy.op_compute_entropy(logits=output_tensor, attention_mask=data.batch["response_mask"]) + entropy_loss = agg_loss( + loss_mat=entropy, + loss_mask=response_mask, + loss_agg_mode=self.pipeline_config.loss_agg_mode, + ) + total_loss = total_loss - entropy_loss * self.pipeline_config.entropy_loss_coef + + pg_metrics = { + "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(), + "actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(), + "actor/ppo_ratio_clipfrac": clipped.mean().detach().item(), + "actor/ratio_mean": masked_mean(ratio, response_mask, dim=-1).mean().detach().item(), + "actor/ratio_max": 
torch.max(ratio * response_mask).detach().item(), + "actor/ratio_min": torch.min(ratio * response_mask + (1 - response_mask) * 1e10).detach().item(), + "actor/clipfrac": agg_loss(loss_mat=torch.lt(surr2, surr1).float(), loss_mask=response_mask, + loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(), + "actor/pg_loss": pg_loss.detach().item(), + "actor/kl_loss": kl_loss.detach().item(), + "actor/total_loss": total_loss.detach().item(), + "actor/approxkl": agg_loss(loss_mat=approxkl, loss_mask=response_mask, + loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(), + "actor/policykl": agg_loss(loss_mat=policykl, loss_mask=response_mask, + loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(), + } + + return total_loss, pg_metrics + diff --git a/roll/pipeline/agentic/agentic_config.py b/roll/pipeline/agentic/agentic_config.py index f9c19c302..f421d74cf 100644 --- a/roll/pipeline/agentic/agentic_config.py +++ b/roll/pipeline/agentic/agentic_config.py @@ -103,7 +103,7 @@ def __post_init__(self): # default worker_cls if self.actor_train.worker_cls is None: - self.actor_train.worker_cls = "roll.pipeline.base_worker.ActorWorker" + self.actor_train.worker_cls = "roll.pipeline.agentic.agentic_actor_worker.ActorWorker" if self.actor_infer.worker_cls is None: self.actor_infer.worker_cls = "roll.pipeline.base_worker.ActorWorker" if self.reference.worker_cls is None: From 79af5c37232a9717e8fad892419c303d937d27f9 Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Tue, 2 Dec 2025 15:53:08 +0800 Subject: [PATCH 17/58] (feat): agentic_filter_task. --- roll/datasets/global_dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/roll/datasets/global_dataset.py b/roll/datasets/global_dataset.py index 67d7dde82..8e1338d0f 100644 --- a/roll/datasets/global_dataset.py +++ b/roll/datasets/global_dataset.py @@ -61,8 +61,10 @@ async def reset(self): async def filter(self, filter_name: str, function: Optional[Callable] = None, **kwargs): if filter_name in self.filter_names: return + logger.info(f"---- before filter-- {filter_name}, dataset_name: {self.dataset_name} len: {len(self.dataset)}") self.dataset = self.dataset.filter(function, **kwargs) self.filter_names.add(filter_name) + logger.info(f"---- after filter-- {filter_name}, dataset_name: {self.dataset_name} len: {len(self.dataset)}") @ray.remote From 7c261c8e6ec3ddbda50d601eef0f2b71b97ad403 Mon Sep 17 00:00:00 2001 From: "weixun.wwx" Date: Fri, 31 Oct 2025 10:25:11 +0800 Subject: [PATCH 18/58] (refactor): agentic pipeline modify. 
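The KL penalty is folded into compute_token_reward, and a new get_agentic_response_level_mask hook reserves a place for response-level filtering (currently a pass-through, see the TODOs). A rough sketch of the token-level shaping that compute_token_reward stands for, with illustrative shapes, an assumed k1 KL approximation, and reward clipping omitted:

```python
import torch

response_level_reward = torch.tensor([1.0])              # one trajectory, already normalized
response_mask = torch.tensor([[1., 1., 1., 0.]])         # last position is padding
log_probs     = torch.tensor([[-1.0, -0.9, -1.1, 0.0]])  # policy log-probs
ref_log_probs = torch.tensor([[-1.1, -1.0, -1.0, 0.0]])  # reference log-probs
beta = 0.01                                              # KL coefficient supplied by kl_ctrl

kl = (log_probs - ref_log_probs) * response_mask         # k1 approximation of the KL term

token_rewards = torch.zeros_like(log_probs)
last_idx = response_mask.sum(dim=-1).long() - 1          # index of the last valid token
token_rewards[0, last_idx] = response_level_reward       # scalar reward lands on the final token
token_rewards = token_rewards - beta * kl                # per-token KL penalty
print(token_rewards)
```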
--- roll/pipeline/agentic/agentic_pipeline.py | 37 +++++++++++------------ roll/pipeline/agentic/utils.py | 27 +++++++++++++++++ roll/utils/functionals.py | 3 +- 3 files changed, 47 insertions(+), 20 deletions(-) diff --git a/roll/pipeline/agentic/agentic_pipeline.py b/roll/pipeline/agentic/agentic_pipeline.py index 4094fc7d5..b85024f3b 100644 --- a/roll/pipeline/agentic/agentic_pipeline.py +++ b/roll/pipeline/agentic/agentic_pipeline.py @@ -17,7 +17,7 @@ from roll.models.model_providers import default_tokenizer_provider from roll.pipeline.agentic.agentic_config import AgenticConfig, EnvManagerConfig from roll.pipeline.agentic.utils import (dump_rollout_render, compute_discounted_returns, - compute_response_level_rewards, dump_rollout_trajectories) + compute_response_level_rewards, dump_rollout_trajectories, get_agentic_response_level_mask) from roll.pipeline.base_pipeline import BasePipeline from roll.utils.constants import RAY_NAMESPACE from roll.utils.functionals import ( @@ -28,6 +28,7 @@ RunningMoments, compute_clip_fraction, agg_loss, + compute_token_reward, ) from roll.utils.kl_controller import get_kl_controller from roll.utils.logging import get_logger @@ -206,30 +207,30 @@ def run(self): metrics.update(reduce_metrics(old_log_probs.meta_info.pop("metrics", {}))) metrics["time/old_log_probs_values"] = cal_old_logpb_timer.last - - with Timer(name="adv", logger=None) as timer: + + # TODO 当前这个还没用处 + with Timer(name="cal_response_level_mask", logger=None) as timer: + # TODO 补充完善的过滤要求,不同环境需要维持统一过滤标识 + batch, mask_metrics = get_agentic_response_level_mask(batch, self.pipeline_config) + metrics.update(mask_metrics) + metrics["time/cal_response_level_mask"] = timer.last + + with Timer(name="cal_response_norm_rewards", logger=None) as timer: # Rewards need to be processed after grouping # We can group by tag(env_type)/traj_group_id(group)/batch(rollout_batch)... to compute rewards / advantages # The compute_response_level_rewards function injects a response_level_rewards key into batch.batch. batch = compute_response_level_rewards(batch=batch, pipeline_config=self.pipeline_config) metrics.update(reduce_metrics(batch.meta_info.pop("metrics", {}))) + metrics["time/cal_norm_rewards"] = timer.last - if self.pipeline_config.reward_clip: - reward_clip_frac = compute_clip_fraction( - values=batch.batch["response_level_rewards"], - clip_max=self.pipeline_config.reward_clip, - clip_min=-self.pipeline_config.reward_clip, - ) - metrics["critic/reward_clip_frac"] = reward_clip_frac - batch.batch["response_level_rewards"] = torch.clamp( - batch.batch["response_level_rewards"], - min=-self.pipeline_config.reward_clip, - max=self.pipeline_config.reward_clip, - ) - + with Timer(name="cal_token_reward", logger=None) as timer: # Expand compute_response_level_rewards and add kl_penalty. - batch, kl_metrics = apply_kl_penalty(data=batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.pipeline_config.kl_penalty) + # batch, kl_metrics = apply_kl_penalty(data=batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.pipeline_config.kl_penalty) + batch, token_level_metrics = compute_token_reward(batch, self.pipeline_config, self.kl_ctrl) + metrics.update(token_level_metrics) + metrics["time/cal_token_reward"] = timer.last + with Timer(name="compute_advantage", logger=None) as timer: # Is the advantage calculated globally across the batch, or within each group? 
batch = compute_advantage( data=batch, @@ -241,8 +242,6 @@ def run(self): whiten_rewards=self.pipeline_config.whiten_rewards, ) metrics.update(reduce_metrics(batch.meta_info.pop("metrics", {}))) - - metrics.update(kl_metrics) metrics["time/adv"] = timer.last if self.pipeline_config.adv_estimator == "gae": diff --git a/roll/pipeline/agentic/utils.py b/roll/pipeline/agentic/utils.py index 14f573d2e..66e49f9b1 100644 --- a/roll/pipeline/agentic/utils.py +++ b/roll/pipeline/agentic/utils.py @@ -181,6 +181,33 @@ def compute_response_level_rewards(batch: "DataProto", pipeline_config: AgenticC return batch +@torch.no_grad() +def get_agentic_response_level_mask(data: "DataProto", pipeline_config: AgenticConfig): + batch_size = data.batch["response_mask"].size(0) + mask_metrics = {} + + # mask相关策略 + data.batch["origin_response_mask"] = data.batch["response_mask"].clone() + response_mask = data.batch["response_mask"][:, 1:].clone() + + final_sample_mask = torch.ones(batch_size, device=response_mask.device) + + if pipeline_config.max_len_mask: + # TODO 当前是混合多个的action/state,需要去判别,或者用别的方式过滤 + final_sample_mask = final_sample_mask + mask_metrics["actor/max_len_mask_ratio"] = 1.0 + else: + mask_metrics["actor/max_len_mask_ratio"] = 1.0 + + expanded_sample_mask = final_sample_mask.unsqueeze(-1).expand_as(response_mask) + final_response_mask = response_mask * expanded_sample_mask + mask_metrics["actor/final_mask_ratio"] = final_sample_mask.mean().item() + mask_metrics["actor/samples_used"] = final_sample_mask.sum().item() + mask_metrics["actor/samples_total"] = float(batch_size) + + data.batch["final_response_mask"] = final_response_mask + return data, mask_metrics + print_only_once = False diff --git a/roll/utils/functionals.py b/roll/utils/functionals.py index 187f94fcd..4097ce30c 100644 --- a/roll/utils/functionals.py +++ b/roll/utils/functionals.py @@ -10,6 +10,7 @@ from tensordict import TensorDict from roll.pipeline.rlvr.rlvr_config import RLVRConfig +from roll.configs.base_config import PPOConfig from roll.platforms import current_platform from roll.utils.kl_controller import AdaptiveKLController from roll.utils.logging import get_logger @@ -504,7 +505,7 @@ def difficulty_mask(data: "DataProto", n_sample=-1, low_threshold=0.1, high_thre @torch.no_grad() -def compute_token_reward(data: "DataProto", pipeline_config: RLVRConfig, kl_ctrl: AdaptiveKLController): +def compute_token_reward(data: "DataProto", pipeline_config: PPOConfig, kl_ctrl: AdaptiveKLController): token_level_rewards = expand_to_token_level(data) beta = 0 kld = compute_approx_kl( From 28c3edd237abd64e175b3a903702492b158da1be Mon Sep 17 00:00:00 2001 From: "hongzhen.yj" Date: Fri, 31 Oct 2025 16:53:59 +0800 Subject: [PATCH 19/58] (fix): update error logging for image loading failure. 
--- roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py b/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py index de38bd157..bd812c1d0 100644 --- a/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py @@ -62,7 +62,7 @@ def encode_function(data_i, processor, prompt_key, answer_key, image_key): image_out = load_images(image if isinstance(image, (list, tuple)) else [image], timeout=None) except Exception as e: image_out = [Image.new("RGB", (224, 224), (255, 255, 255))] - logger.error(f"Failed to get image: {image}") + logger.error(f"Failed to get image due to {e}") # since infer-image use pil image as input while train-engine use # processed data, process image here to make them use same image image_out = process_images(image_out, processor) From 703804011074f51dae45ce9bc42f25eab20aa2ce Mon Sep 17 00:00:00 2001 From: "weixun.wwx" Date: Fri, 31 Oct 2025 17:18:36 +0800 Subject: [PATCH 20/58] (fix): fix max_len_mask key. --- roll/pipeline/agentic/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roll/pipeline/agentic/utils.py b/roll/pipeline/agentic/utils.py index 66e49f9b1..d7b86f4ee 100644 --- a/roll/pipeline/agentic/utils.py +++ b/roll/pipeline/agentic/utils.py @@ -192,7 +192,7 @@ def get_agentic_response_level_mask(data: "DataProto", pipeline_config: AgenticC final_sample_mask = torch.ones(batch_size, device=response_mask.device) - if pipeline_config.max_len_mask: + if getattr(pipeline_config, "max_len_mask", False): # TODO 当前是混合多个的action/state,需要去判别,或者用别的方式过滤 final_sample_mask = final_sample_mask mask_metrics["actor/max_len_mask_ratio"] = 1.0 From aa6ad590ecb9d12509b0bbb6fe185614baf3ec2a Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Tue, 2 Dec 2025 15:54:50 +0800 Subject: [PATCH 21/58] (feat): add infer_log_probs in agentic. 
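When the generator is asked for logprobs, each decision step now carries the vLLM-side log-probs of its response tokens; the env managers pad them with zeros over prompt positions so they line up with the concatenated trajectory, and the actor worker reports train/infer ratio and diff metrics from them. An illustrative sketch of that alignment and of the ratio metric (list and tensor names are made up; the real batch is built from rollout history):

```python
import torch

steps = [
    {"prompt_ids": [11, 12, 13], "response_ids": [21, 22], "infer_logprobs": [-0.7, -1.3]},
    {"prompt_ids": [14],         "response_ids": [23],     "infer_logprobs": [-0.2]},
]

token_ids, response_mask, infer_logprobs = [], [], []
for s in steps:
    token_ids      += s["prompt_ids"] + s["response_ids"]
    response_mask  += [0] * len(s["prompt_ids"]) + [1] * len(s["response_ids"])
    # prompt positions get placeholder zeros, response positions get the generator's logprobs
    infer_logprobs += [0.0] * len(s["prompt_ids"]) + s["infer_logprobs"]

infer = torch.tensor(infer_logprobs)
mask  = torch.tensor(response_mask, dtype=torch.float)
train = torch.tensor([0.0, 0.0, 0.0, -0.8, -1.1, 0.0, -0.3])  # trainer log-probs, same layout

# train/infer gap metric of the kind the actor worker logs
ratio = (train - infer).exp()
print((ratio * mask).sum() / mask.sum())  # masked mean of exp(log p_train - log p_infer)
```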
--- roll/distributed/scheduler/generate_scheduler.py | 1 + roll/pipeline/agentic/agentic_actor_worker.py | 11 +++++++++++ .../agentic/env_manager/step_env_manager.py | 13 ++++++++++--- .../agentic/env_manager/traj_env_manager.py | 13 +++++++++++++ 4 files changed, 35 insertions(+), 3 deletions(-) diff --git a/roll/distributed/scheduler/generate_scheduler.py b/roll/distributed/scheduler/generate_scheduler.py index 617539984..d968eb5aa 100644 --- a/roll/distributed/scheduler/generate_scheduler.py +++ b/roll/distributed/scheduler/generate_scheduler.py @@ -918,6 +918,7 @@ async def generate_one_request(self, data: DataProto): eos_token_id=eos_token_id, pad_token_id=pad_token_id, pad_to_seq_len=data.meta_info.get("pad_to_seq_len", True), + output_logprobs=output_logprobs, ) request_repeat = data.repeat(repeat_times=len(output_tokens)) output.non_tensor_batch = request_repeat.non_tensor_batch diff --git a/roll/pipeline/agentic/agentic_actor_worker.py b/roll/pipeline/agentic/agentic_actor_worker.py index f748096bc..c1e273e3e 100644 --- a/roll/pipeline/agentic/agentic_actor_worker.py +++ b/roll/pipeline/agentic/agentic_actor_worker.py @@ -16,6 +16,9 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): response_mask = data.batch["response_mask"][:, 1:].long() ref_log_probs = data.batch["ref_log_probs"] old_log_probs = data.batch["old_log_probs"] + infer_log_probs = data.batch.get("infer_logprobs", old_log_probs) + infer_log_probs = infer_log_probs if len(infer_log_probs) > 0 else old_log_probs + advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( @@ -23,6 +26,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): ) ratio = (log_probs - old_log_probs).exp() + train_infer_ratio = (log_probs - infer_log_probs).exp() + train_infer_diff = log_probs.exp() - infer_log_probs.exp() pg_clip_low = self.pipeline_config.pg_clip_low if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip pg_clip_high = self.pipeline_config.pg_clip_high if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip @@ -62,6 +67,11 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): ) total_loss = total_loss - entropy_loss * self.pipeline_config.entropy_loss_coef + train_infer_prob_metric = { + "actor/train_infer_ratio_mean": masked_mean(train_infer_ratio, response_mask, dim=-1).mean().detach().item(), + "actor/train_infer_diff_mean": masked_mean(train_infer_diff, response_mask, dim=-1).mean().detach().item(), + } + pg_metrics = { "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(), "actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(), @@ -78,6 +88,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(), "actor/policykl": agg_loss(loss_mat=policykl, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(), + **train_infer_prob_metric } return total_loss, pg_metrics diff --git a/roll/pipeline/agentic/env_manager/step_env_manager.py b/roll/pipeline/agentic/env_manager/step_env_manager.py index 737910394..4348605a3 100644 --- a/roll/pipeline/agentic/env_manager/step_env_manager.py +++ b/roll/pipeline/agentic/env_manager/step_env_manager.py @@ -79,6 +79,9 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): input_ids =torch.tensor(token_ids, dtype=torch.long).unsqueeze(0) attention_mask = torch.tensor([1] * len(token_ids), dtype=torch.long).unsqueeze(0) 
response_mask = torch.tensor(response_masks, dtype=torch.bool).unsqueeze(0) + infer_logprobs = [] + if "infer_logprobs" in history: + infer_logprobs = [0] * len(history["prompt_ids"]) + history["infer_logprobs"] first_response_idx = response_masks.index(1) prompt_masks = [1] * first_response_idx + [0] * (len(token_ids) - first_response_idx) @@ -93,8 +96,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): response_mask = pad_to_length(response_mask, length=self.pipeline_config.sequence_length, pad_value=0) prompt_mask = pad_to_length(prompt_mask, length=self.pipeline_config.sequence_length, pad_value=0) score_tensor = pad_to_length(score_tensor, length=self.pipeline_config.sequence_length, pad_value=0) - - samples.append(DataProto( + lm_input = DataProto( batch=TensorDict( { "input_ids": input_ids, @@ -114,8 +116,13 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): "state_hash": np.array([history['state_hash']], dtype=object), "step": np.array([step], dtype=object), } - )) + ) + if len(infer_logprobs): + infer_logprobs = torch.tensor(infer_logprobs, dtype=torch.float).unsqueeze(0) + infer_logprobs = pad_to_length(infer_logprobs, length=self.pipeline_config.sequence_length, pad_value=0) + lm_input.batch["infer_logprobs"] = infer_logprobs[:, 1:] + samples.append(lm_input) batch: DataProto = DataProto.concat(samples) response_length = batch.batch["response_mask"].float().sum(-1).mean().item() diff --git a/roll/pipeline/agentic/env_manager/traj_env_manager.py b/roll/pipeline/agentic/env_manager/traj_env_manager.py index 050594636..88ab15d91 100644 --- a/roll/pipeline/agentic/env_manager/traj_env_manager.py +++ b/roll/pipeline/agentic/env_manager/traj_env_manager.py @@ -225,6 +225,11 @@ def make_decision(self, rollout_cache: RolloutCache): response_ids = lm_output.batch['responses'][0] response_ids = response_ids.tolist() content = self.rollout_cache.history[-1] + + if "infer_logprobs" in lm_output.batch: + infer_logprobs = lm_output.batch['infer_logprobs'][0][-len(response_ids):] + content["infer_logprobs"] = infer_logprobs.tolist() + content["response_ids"] = response_ids content["messages"].append({"role": "assistant", "content": self.tokenizer.decode(response_ids, skip_special_tokens=True)}) lm_output.meta_info["stop_reason"] = GenerateStopReason.FINISH @@ -293,11 +298,14 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): token_ids = [] prompt_masks = [] response_masks = [] + infer_logprobs = [] for items in self.rollout_cache.history: token_ids.extend(items["prompt_ids"]) token_ids.extend(items["response_ids"]) prompt_masks.extend([1] * len(items["prompt_ids"]) + [0] * len(items["response_ids"])) response_masks.extend([0] * len(items["prompt_ids"]) + [1] * len(items["response_ids"])) + if "infer_logprobs" in items: + infer_logprobs.extend([0] * len(items["prompt_ids"]) + items["infer_logprobs"]) input_ids =torch.tensor(token_ids, dtype=torch.long).unsqueeze(0) attention_mask = torch.tensor([1] * len(token_ids), dtype=torch.long).unsqueeze(0) @@ -337,6 +345,11 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): "prompt_mask": prompt_mask, "scores": score_tensor, }) + if len(infer_logprobs): + infer_logprobs = torch.tensor(infer_logprobs, dtype=torch.float).unsqueeze(0) + infer_logprobs = pad_to_length(infer_logprobs, length=self.pipeline_config.sequence_length, pad_value=0) + lm_input.batch["infer_logprobs"] = infer_logprobs[:, 1:] + lm_input.non_tensor_batch.update({ "env_ids": np.array([self.rollout_cache.env_id], dtype=object), "group_ids": 
np.array([self.rollout_cache.group_id], dtype=object), From 48c2253f03ae46f6b390d94e70a17f8ea9f736ae Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Fri, 5 Dec 2025 17:11:30 +0800 Subject: [PATCH 22/58] (feat): update mcore_adapter. --- mcore_adapter/examples/train/run_train.py | 32 ++ .../src/mcore_adapter/adapters/__init__.py | 35 ++ .../src/mcore_adapter/adapters/lora_layer.py | 536 ++++++++++++++++++ .../src/mcore_adapter/adapters/utils.py | 50 ++ .../models/converter/dist_converter.py | 85 ++- .../models/converter/post_converter.py | 69 ++- .../models/converter/template.py | 61 +- .../src/mcore_adapter/models/model_config.py | 16 +- .../src/mcore_adapter/models/model_factory.py | 41 +- .../src/mcore_adapter/models/model_utils.py | 3 +- mcore_adapter/src/mcore_adapter/utils.py | 4 +- 11 files changed, 897 insertions(+), 35 deletions(-) create mode 100644 mcore_adapter/src/mcore_adapter/adapters/__init__.py create mode 100644 mcore_adapter/src/mcore_adapter/adapters/lora_layer.py create mode 100644 mcore_adapter/src/mcore_adapter/adapters/utils.py diff --git a/mcore_adapter/examples/train/run_train.py b/mcore_adapter/examples/train/run_train.py index 9f38771ec..a77cdf173 100644 --- a/mcore_adapter/examples/train/run_train.py +++ b/mcore_adapter/examples/train/run_train.py @@ -16,9 +16,11 @@ from llamafactory.train.dpo import run_dpo from llamafactory.train.pt import run_pt from llamafactory.train.sft import run_sft +from peft import LoraConfig, get_peft_model from transformers import DataCollatorForSeq2Seq, HfArgumentParser from transformers.trainer_callback import TrainerCallback +from mcore_adapter.adapters import apply_megatron_lora, find_all_linear_modules, set_linear_is_expert from mcore_adapter.models import AutoConfig, AutoModel from mcore_adapter.trainer import DPOTrainer, McaTrainer from mcore_adapter.trainer.dpo_config import DPOConfig @@ -110,6 +112,28 @@ def wrapper(features: Sequence[Dict[str, Any]]): return wrapper +def setup_lora_training(model, finetuning_args): + model.enable_input_require_grads() + target_modules = find_all_linear_modules(model) + + peft_kwargs = { + "r": finetuning_args.lora_rank, + "target_modules": target_modules, + "lora_alpha": finetuning_args.lora_alpha, + "lora_dropout": finetuning_args.lora_dropout, + } + + lora_config = LoraConfig( + **peft_kwargs, + ) + model = get_peft_model(model, lora_config) + + for param in filter(lambda p: p.requires_grad, model.parameters()): + param.data = param.data.to(torch.float32) + + return model + + def pt_mca_train( training_args: Seq2SeqTrainingArguments, model_args: ModelArguments, @@ -120,6 +144,10 @@ def pt_mca_train( tokenizer = tokenizer_module["tokenizer"] template = get_template_and_fix_tokenizer(tokenizer, data_args) model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args) + if finetuning_args.finetuning_type == "lora": + apply_megatron_lora() + set_linear_is_expert(model[0]) + model.models[0] = setup_lora_training(model[0], finetuning_args) data_args.cutoff_len += 1 dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module) data_args.cutoff_len -= 1 @@ -165,6 +193,10 @@ def sft_mca_train( data_args.cutoff_len += 1 dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module) data_args.cutoff_len -= 1 + if finetuning_args.finetuning_type == "lora": + apply_megatron_lora() + set_linear_is_expert(model[0]) + model.models[0] = setup_lora_training(model[0], finetuning_args) if 
model.config.hf_model_type in ["qwen2_vl"] and finetuning_args.freeze_vision_tower: for name, p in model.named_parameters(): if any(name.startswith(k) for k in ["vision_model.blocks", "vision_model.patch_embed"]): diff --git a/mcore_adapter/src/mcore_adapter/adapters/__init__.py b/mcore_adapter/src/mcore_adapter/adapters/__init__.py new file mode 100644 index 000000000..b223260a7 --- /dev/null +++ b/mcore_adapter/src/mcore_adapter/adapters/__init__.py @@ -0,0 +1,35 @@ +from ..utils import get_logger, is_peft_available + + +logger = get_logger(__name__) + +if is_peft_available(): + from .lora_layer import apply_megatron_lora + from .utils import ( + find_all_embedding_modules, + find_all_linear_modules, + find_all_router_modules, + set_linear_is_expert, + ) +else: + + def apply_megatron_lora(*args, **kwargs): + raise ValueError("PEFT is not available. Please install PEFT to use LoRA adapters.") + + def find_all_linear_modules(model): + raise ValueError("PEFT is not available. Please install PEFT to use LoRA adapters.") + + def find_all_embedding_modules(model): + raise ValueError("PEFT is not available. Please install PEFT to use LoRA adapters.") + + def find_all_router_modules(model): + raise ValueError("PEFT is not available. Please install PEFT to use LoRA adapters.") + + +__all__ = [ + "apply_megatron_lora", + "find_all_linear_modules", + "find_all_embedding_modules", + "find_all_router_modules", + "set_linear_is_expert", +] diff --git a/mcore_adapter/src/mcore_adapter/adapters/lora_layer.py b/mcore_adapter/src/mcore_adapter/adapters/lora_layer.py new file mode 100644 index 000000000..ad7630a44 --- /dev/null +++ b/mcore_adapter/src/mcore_adapter/adapters/lora_layer.py @@ -0,0 +1,536 @@ +import math +from contextlib import contextmanager +from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.extensions.transformer_engine import ( + TEColumnParallelGroupedLinear, + TEColumnParallelLinear, + TEGroupedLinear, + TELayerNormColumnParallelLinear, + TELinear, + TERowParallelGroupedLinear, + TERowParallelLinear, +) +from megatron.core.parallel_state import ( + get_expert_model_parallel_rank, + get_expert_model_parallel_world_size, + get_expert_tensor_parallel_world_size, + get_tensor_model_parallel_world_size, +) +from megatron.core.tensor_parallel import gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region +from megatron.core.transformer.mlp import apply_swiglu_sharded_factory +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.moe.router import TopKRouter +from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default +from peft.tuners.lora import model +from peft.tuners.lora.layer import LoraLayer +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils import transpose + + +class LoraParallelLinear(MegatronModule, LoraLayer): + def __init__( + self, + base_layer, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + fan_in_fan_out: bool = False, + init_lora_weights: bool = True, + use_rslora: bool = False, + use_dora: bool = False, + lora_bias: bool = False, + **kwargs, + ): + config = base_layer.config + super().__init__(config=config) + LoraLayer.__init__(self, base_layer=base_layer) + + # lora needs to be forced to upgrade to 32-bit precision, otherwise it will 
overflow + self.config.params_dtype = torch.float32 + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + self.is_grouped = isinstance(base_layer, TEGroupedLinear) + self.fan_in_fan_out = fan_in_fan_out + self._active_adapter = adapter_name + self.is_expert = getattr(base_layer, "is_expert", False) + self.sequence_parallel = getattr(base_layer, "sequence_parallel", False) + if self.is_expert: + self.tp_size = get_expert_tensor_parallel_world_size() + else: + self.tp_size = get_tensor_model_parallel_world_size() + + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + lora_bias=lora_bias, + ) + + self.is_target_conv_1d_layer = False + + def _create_lora_layers(self, r, lora_bias, **kwargs): + """Create LoRA A and B layers. To be implemented by subclasses.""" + raise NotImplementedError("_create_lora_layers must be implemented in subclasses") + + def update_layer( + self, adapter_name, r, *, lora_alpha, lora_dropout, init_lora_weights, use_rslora, lora_bias, **kwargs + ): + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout[adapter_name] = lora_dropout_layer + + # Create LoRA layers based on subclass implementation + lora_layer_kwargs = { + "skip_bias_add": False, + "init_method": self.config.init_method, + "config": self.config, + "is_expert": self.is_expert, + "tp_group": self.base_layer.tp_group + } + lora_a, lora_b = self._create_lora_layers(r, lora_bias, **lora_layer_kwargs) + + # Disable ub_overlap for parallel layers + for lora in [lora_a, lora_b]: + if isinstance(lora, (TERowParallelLinear, TEColumnParallelLinear)) and lora.parallel_mode is None: + lora.ub_overlap_rs_fprop = False + lora.ub_overlap_ag_dgrad = False + lora.ub_overlap_ag_fprop = False + lora.ub_overlap_rs_dgrad = False + + # Disable sequence parallel for LoRA layers + lora_a.sequence_parallel = False + lora_b.sequence_parallel = False + + self.lora_A[adapter_name] = lora_a + self.lora_B[adapter_name] = lora_b + + if hasattr(self, "lora_bias"): + self.lora_bias[adapter_name] = lora_bias + if use_rslora: + self.scaling[adapter_name] = lora_alpha / (r**0.5) + else: + self.scaling[adapter_name] = lora_alpha / r + + if init_lora_weights: + self.reset_lora_parameters(adapter_name, init_lora_weights) + self._move_adapter_to_device_of_base_layer(adapter_name) + self.set_adapter(self.active_adapters) + + def reset_lora_parameters(self, adapter_name, init_lora_weights): + if init_lora_weights is False: + return + + if adapter_name in self.lora_A.keys(): + lora_a = self.lora_A[adapter_name] + lora_b = self.lora_B[adapter_name] + if isinstance(lora_a, TEGroupedLinear): + weights_a = [getattr(lora_a, f"weight{i}") for i in range(lora_a.num_gemms)] + else: + weights_a = [lora_a.weight] + if isinstance(lora_b, TEGroupedLinear): + weights_b = [getattr(lora_b, f"weight{i}") for i in range(lora_b.num_gemms)] + else: + weights_b = [lora_b.weight] + for weight_a in weights_a: + if init_lora_weights is True: + # initialize A the same way as the default for nn.Linear and B to zero + # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124 + 
nn.init.kaiming_uniform_(weight_a, a=math.sqrt(5)) + elif init_lora_weights.lower() == "gaussian": + nn.init.normal_(weight_a, std=1 / self.r[adapter_name]) + else: + raise ValueError(f"Unknown initialization {init_lora_weights=}") + for weight_b in weights_b: + nn.init.zeros_(weight_b) + if adapter_name in self.lora_embedding_A.keys(): + # Initialize A to zeros and B the same way as the default for nn.Embedding, see: + # https://github.com/microsoft/LoRA/blob/4c0333854cb905966f8cc4e9a74068c1e507c7b7/loralib/layers.py#L59-L60 + nn.init.zeros_(self.lora_embedding_A[adapter_name]) + nn.init.normal_(self.lora_embedding_B[adapter_name]) + + @contextmanager + def _patch_router_gating(self): + origin_gating = self.base_layer.__class__.gating + + def gating(_self, x): + result = origin_gating(_self, x) + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = x.to(result.dtype) + + lora_result = F.linear(dropout(x), lora_A.weight.to(result.dtype)) + if isinstance(lora_result, tuple): + lora_result = lora_result[0] + lora_result = F.linear(lora_result, lora_B.weight.to(result.dtype)) + if isinstance(lora_result, tuple): + lora_result = lora_result[0] + lora_result = lora_result * scaling + + result = result + lora_result + return result + + self.base_layer.__class__.gating = gating + try: + yield + finally: + self.base_layer.__class__.gating = origin_gating + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any): + previous_dtype = x.dtype + if self.disable_adapters and self.merged: + self.unmerge() + + if isinstance(self.base_layer, TELayerNormColumnParallelLinear): + if self.disable_adapters or self.merged: + self.base_layer.return_layernorm_output = False + result, bias = self.base_layer(x, *args, **kwargs) + else: + self.base_layer.return_layernorm_output = True + (result, x), bias = self.base_layer(x, *args, **kwargs) + elif isinstance(self.base_layer, (TELinear, TEGroupedLinear)): + result, bias = self.base_layer(x, *args, **kwargs) + elif isinstance(self.base_layer, TopKRouter): + with self._patch_router_gating(): + result, bias = self.base_layer(x, *args, **kwargs) + else: + raise ValueError(f"Unsupported base layer type: {type(self.base_layer)}") + + if not isinstance(self.base_layer, TopKRouter) and not self.disable_adapters and not self.merged: + if self.sequence_parallel and self.base_layer.parallel_mode == "column": + x = gather_from_sequence_parallel_region(x) + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + dtype = lora_A.weight0.dtype if isinstance(lora_A, TEGroupedLinear) else lora_A.weight.dtype + x = x.to(dtype) + + lora_result = ( + lora_A(dropout(x), *args, **kwargs) if isinstance(lora_A, TEGroupedLinear) else lora_A(dropout(x)) + ) + if isinstance(lora_result, tuple): + lora_result = lora_result[0] + lora_result = ( + lora_B(lora_result, *args, **kwargs) + if isinstance(lora_B, TEGroupedLinear) + else lora_B(lora_result) + ) + if isinstance(lora_result, tuple): + lora_result = lora_result[0] + lora_result = lora_result * scaling + + if self.sequence_parallel and self.base_layer.parallel_mode == "row": + lora_result = 
scatter_to_sequence_parallel_region(lora_result) + result = result + lora_result + + result = result.to(previous_dtype) + return result, bias + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + """ + Merge the active adapter weights into the base weights + + Args: + safe_merge (`bool`, *optional*): + If True, the merge operation will be performed in a copy of the original weights and check for NaNs + before merging the weights. This is useful if you want to check if the merge operation will produce + NaNs. Defaults to `False`. + adapter_names (`list[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + base_layer = self.get_base_layer() + origin_device = base_layer.weight0.device if self.is_grouped else base_layer.weight.device + if origin_device.type == "cpu": + self.to(device=torch.cuda.current_device()) + + for active_adapter in adapter_names: + if active_adapter in self.lora_A.keys(): + if self.is_grouped: + orig_weights = [getattr(base_layer, f"weight{i}") for i in range(base_layer.num_gemms)] + else: + orig_weights = [base_layer.weight] + + if safe_merge: + # Note that safe_merge will be slower than the normal merge + # because of the copy operation. + orig_weights = [weight.data.clone() for weight in orig_weights] + delta_weights = self.get_delta_weights(active_adapter) + for orig_weight, delta_weight in zip(orig_weights, delta_weights): + orig_weight += delta_weight + if not all(torch.isfinite(orig_weights[i]).all() for i in range(len(orig_weights))): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + if self.is_grouped: + for i in range(base_layer.num_gemms): + weight = getattr(base_layer, f"weight{i}") + weight.data = orig_weights[i] + else: + base_layer.weight.data = orig_weights[0] + else: + delta_weights = self.get_delta_weights(active_adapter) + for orig_weight, delta_weight in zip(orig_weights, delta_weights): + orig_weight.data += delta_weight + self.merged_adapters.append(active_adapter) + + if origin_device.type == "cpu": + self.to(device=origin_device) + + def sharded_state_dict( + self, + prefix: str = "", + sharded_offsets: Tuple[Tuple[int, int, int]] = (), + metadata: Optional[dict] = None, + ) -> ShardedStateDict: + sharded_state_dict = {} + # Save parameters + self._save_to_state_dict(sharded_state_dict, "", keep_vars=True) + sharded_state_dict = make_sharded_tensors_for_checkpoint( + sharded_state_dict, prefix, sharded_offsets=sharded_offsets + ) + # Recurse into submodules + for name, module in self.named_children(): + if "Dict" in module.__class__.__name__: + modules = module.named_children() + else: + modules = [(None, module)] + for n, m in modules: + _prefix = f"{prefix}{name}." if n is None else f"{prefix}{name}.{n}." 
+ sharded_state_dict.update(sharded_state_dict_default(m, _prefix, sharded_offsets, metadata)) + + if prefix.endswith("linear_fc1."): + if isinstance(self.base_layer, TEGroupedLinear) and self.config.gated_linear_unit: + num_global_experts = get_expert_model_parallel_world_size() * self.base_layer.num_gemms + local_expert_indices_offset = get_expert_model_parallel_rank() * self.base_layer.num_gemms + ep_axis = len(sharded_offsets) + for i in range(self.base_layer.num_gemms): + new_sharded_offsets = ( + *sharded_offsets, + (ep_axis, local_expert_indices_offset + i, num_global_experts), + ) + for k in (f"{prefix}base_layer.weight{i}", f"{prefix}base_layer.bias{i}"): + if k in sharded_state_dict: + sharded_state_dict[k] = apply_swiglu_sharded_factory( + sharded_state_dict[k], new_sharded_offsets + ) + else: + for k, v in sharded_state_dict.items(): + if k in [f"{prefix}base_layer.weight", f"{prefix}base_layer.bias"]: + sharded_state_dict[k] = apply_swiglu_sharded_factory(sharded_state_dict[k], sharded_offsets) + return sharded_state_dict + + def get_delta_weights(self, adapter) -> list[torch.Tensor]: + """ + Compute the delta weight for the given adapter. + + Args: + adapter (str): + The name of the adapter for which the delta weight should be computed. + """ + lora_A = self.lora_A[adapter] + lora_B = self.lora_B[adapter] + if self.is_grouped: + weight_A = [getattr(lora_A, f"weight{i}") for i in range(lora_A.num_gemms)] + weight_B = [getattr(lora_B, f"weight{i}") for i in range(lora_B.num_gemms)] + else: + weight_A = [self.lora_A[adapter].weight] + weight_B = [self.lora_B[adapter].weight] + output_tensor = [] + assert len(weight_A) == len(weight_B) + for i in range(len(weight_B)): + output_tensor.append(transpose(weight_B[i] @ weight_A[i], self.fan_in_fan_out) * self.scaling[adapter]) + + return output_tensor + + +class LoraRouterParallelLinear(LoraParallelLinear): + """LoRA layer for TopKRouter""" + + def _create_lora_layers(self, r, lora_bias, **kwargs): + router_shape = self.base_layer.weight.shape + lora_a = TELinear( + input_size=router_shape[1], + output_size=r, + bias=lora_bias, + parallel_mode=None, + skip_weight_param_allocation=False, + **kwargs, + ) + lora_b = TELinear( + input_size=r, + output_size=router_shape[0], + bias=lora_bias, + parallel_mode=None, + skip_weight_param_allocation=False, + **kwargs, + ) + return lora_a, lora_b + + +class LoraRowParallelLinear(LoraParallelLinear): + """LoRA layer for row parallel linear layers""" + + def _create_lora_layers(self, r, lora_bias, **kwargs): + in_features = self.in_features * self.tp_size + + if self.is_grouped: + lora_a = TERowParallelGroupedLinear( + num_gemms=self.base_layer.num_gemms, + input_size=in_features, + output_size=r, + bias=False, + **kwargs, + ) + lora_b = TEGroupedLinear( + num_gemms=self.base_layer.num_gemms, + input_size=r, + output_size=self.out_features, + bias=lora_bias, + parallel_mode=None, + **kwargs, + ) + else: + lora_a = TERowParallelLinear( + input_size=in_features, + output_size=r, + bias=False, + input_is_parallel=True, + **kwargs, + ) + lora_b = TELinear( + input_size=r, + output_size=self.out_features, + bias=lora_bias, + parallel_mode=None, + skip_weight_param_allocation=False, + **kwargs, + ) + lora_a.parallel_mode = self.base_layer.parallel_mode # fix moe_shared_expert_overlap + + return lora_a, lora_b + + +class LoraColumnParallelLinear(LoraParallelLinear): + """LoRA layer for column parallel linear layers""" + + def _create_lora_layers(self, r, lora_bias, **kwargs): + out_features = 
self.out_features * self.tp_size + + if self.is_grouped: + lora_a = TEGroupedLinear( + num_gemms=self.base_layer.num_gemms, + input_size=self.in_features, + output_size=r, + bias=lora_bias, + parallel_mode=None, + **kwargs, + ) + lora_b = TEColumnParallelGroupedLinear( + num_gemms=self.base_layer.num_gemms, + input_size=r, + output_size=out_features, + bias=lora_bias, + **kwargs, + ) + else: + lora_a = TELinear( + input_size=self.in_features, + output_size=r, + bias=lora_bias, + parallel_mode=None, + skip_weight_param_allocation=False, + **kwargs, + ) + lora_b = TEColumnParallelLinear( + input_size=r, + output_size=out_features, + bias=lora_bias, + gather_output=False, + **kwargs, + ) + lora_b.parallel_mode = self.base_layer.parallel_mode # fix moe_shared_expert_overlap + + return lora_a, lora_b + + +def dispatch_megatron( + target: torch.nn.Module, + adapter_name: str, + lora_config, + **kwargs: Any, +) -> Optional[torch.nn.Module]: + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if isinstance(target_base_layer, TopKRouter): + new_module = LoraRouterParallelLinear(base_layer=target, adapter_name=adapter_name, **kwargs) + elif isinstance(target_base_layer, (TERowParallelLinear, TERowParallelGroupedLinear)): + new_module = LoraRowParallelLinear(base_layer=target, adapter_name=adapter_name, **kwargs) + elif isinstance(target_base_layer, (TEColumnParallelLinear, TEColumnParallelGroupedLinear, TELayerNormColumnParallelLinear)): + new_module = LoraColumnParallelLinear(base_layer=target, adapter_name=adapter_name, **kwargs) + elif isinstance(target_base_layer, (TELinear, TEGroupedLinear)): + # default to column parallel linear for non-parallel linear layers + new_module = LoraColumnParallelLinear(base_layer=target, adapter_name=adapter_name, **kwargs) + + return new_module + +def patch_TELinear(): + def __repr__(self): + return ( + f"{type(self).__name__}(in_features={self.in_features}, " + f"out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})" + ) + + setattr(TELinear, "__repr__", __repr__) + + +def patch_TEGroupedLinear(): + def sharded_state_dict( + self, + prefix: str = "", + sharded_offsets: Tuple[Tuple[int, int, int]] = (), + metadata: Optional[dict] = None, + ): + return self._sharded_state_dict_grouped(None, prefix, sharded_offsets, metadata) + + setattr(TEGroupedLinear, "sharded_state_dict", sharded_state_dict) + + +def apply_megatron_lora(): + patch_TELinear() + patch_TEGroupedLinear() + model.dispatch_megatron = dispatch_megatron diff --git a/mcore_adapter/src/mcore_adapter/adapters/utils.py b/mcore_adapter/src/mcore_adapter/adapters/utils.py new file mode 100644 index 000000000..f8bde73e8 --- /dev/null +++ b/mcore_adapter/src/mcore_adapter/adapters/utils.py @@ -0,0 +1,50 @@ +import re +from typing import Callable + +from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear +from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.transformer.moe.router import TopKRouter +from transformers import PreTrainedModel + + +def set_linear_is_expert(model): + for n, module in model.named_modules(): + if ( + ".experts." 
in n + and isinstance(module, (TELinear, TELayerNormColumnParallelLinear)) + or isinstance(module, TEGroupedLinear) + ): + module.is_expert = True + + +def find_layers(model: "PreTrainedModel", cond: Callable): + inner_nodes = set() + for name, module in model.named_modules(): + name = re.sub(r"\d+\.", "{}.", name) + if not cond(module): + inner_nodes.add(name) + target_module_names = set() + for name, module in model.named_modules(): + if cond(module): + module_name_list = name.split(".") + module_name = module_name_list.pop() + for inner_node in inner_nodes: + processed_module_name = re.sub(r"\d+\.", "{}.", module_name) + while module_name_list and inner_node.endswith(processed_module_name): + module_name = f"{module_name_list.pop()}.{module_name}" + target_module_names.add(module_name) + return list(target_module_names) + + +def find_all_linear_modules(model): + return find_layers( + model, lambda module: isinstance(module, (TELinear, TEGroupedLinear, TELayerNormColumnParallelLinear)) + ) + + +def find_all_embedding_modules(model): + return find_layers(model, lambda module: isinstance(module, LanguageModelEmbedding)) + + +def find_all_router_modules(model): + return find_layers(model, lambda module: isinstance(module, TopKRouter)) diff --git a/mcore_adapter/src/mcore_adapter/models/converter/dist_converter.py b/mcore_adapter/src/mcore_adapter/models/converter/dist_converter.py index 80917a8da..d79ba3361 100644 --- a/mcore_adapter/src/mcore_adapter/models/converter/dist_converter.py +++ b/mcore_adapter/src/mcore_adapter/models/converter/dist_converter.py @@ -1,8 +1,9 @@ import fnmatch import os +import warnings from dataclasses import dataclass, field from itertools import product -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import torch from megatron.core.transformer.pipeline_parallel_layer_layout import LayerType, PipelineParallelLayerLayout @@ -50,6 +51,7 @@ class DistParallelConfig: swiglu_weights: List[str] = field(default_factory=list) # ungrouped TE name to grouped + grouped_duplicated_map: Dict[str, str] = field(default_factory=dict) grouped_column_map: Dict[str, str] = field(default_factory=dict) grouped_row_map: Dict[str, str] = field(default_factory=dict) @@ -57,9 +59,10 @@ class DistParallelConfig: def __post_init__(self): self.local_to_te_key_map = {v: k for k, v in self.te_to_local_key_map.items()} + self.grouped_duplicated_weights = list(self.grouped_duplicated_map.keys()) + list(self.grouped_duplicated_map.values()) self.grouped_column_weights = list(self.grouped_column_map.keys()) + list(self.grouped_column_map.values()) self.grouped_row_weights = list(self.grouped_row_map.keys()) + list(self.grouped_row_map.values()) - self.grouped_map = {**self.grouped_column_map, **self.grouped_row_map} + self.grouped_map = {**self.grouped_duplicated_map, **self.grouped_column_map, **self.grouped_row_map} self.grouped_reverse_map = {v: k for k, v in self.grouped_map.items()} def merge_configs(self, other: "DistParallelConfig") -> "DistParallelConfig": @@ -76,12 +79,36 @@ def merge_configs(self, other: "DistParallelConfig") -> "DistParallelConfig": column_parallel_weights=self.column_parallel_weights + other.column_parallel_weights, row_parallel_weights=self.row_parallel_weights + other.row_parallel_weights, swiglu_weights=self.swiglu_weights + other.swiglu_weights, + grouped_duplicated_map={**self.grouped_duplicated_map, **other.grouped_duplicated_map}, grouped_column_map={**self.grouped_column_map, 
**other.grouped_column_map}, grouped_row_map={**self.grouped_row_map, **other.grouped_row_map}, te_to_local_key_map={**self.te_to_local_key_map, **other.te_to_local_key_map}, ) +lora_config = DistParallelConfig( + duplicated_weights=[ + ".self_attention.linear_proj.lora_B.*.weight", + ".self_attention.linear_qkv.lora_A.*.weight", + ".mlp.linear_fc1.lora_A.*.weight", + ".linear_fc1.lora_A.*.weight", + ".mlp.linear_fc2.lora_B.*.weight", + ".linear_fc2.lora_B.*.weight", + ], + column_parallel_weights=[ + ".self_attention.linear_qkv.lora_B.*.weight", + ".mlp.linear_fc1.lora_B.*.weight", + ".linear_fc1.lora_B.*.weight", + ], + row_parallel_weights=[ + ".self_attention.linear_proj.lora_A.*.weight", + ".mlp.linear_fc2.lora_A.*.weight", + ".linear_fc2.lora_A.*.weight", + ], + swiglu_weights=[".mlp.linear_fc1.lora_B.*.weight", ".linear_fc1.lora_B.*.weight"], +) + + default_dist_config = DistParallelConfig( pre_process_weights=[MCORE_WORD_EMBEDDING], post_process_weights=[MCORE_LM_HEAD, "decoder.final_layernorm.weight"], @@ -109,13 +136,23 @@ def merge_configs(self, other: "DistParallelConfig") -> "DistParallelConfig": ".self_attention.linear_qkv.layer_norm_weight": ".input_layernorm.weight", ".mlp.linear_fc1.layer_norm_weight": ".pre_mlp_layernorm.weight", }, +).merge_configs(lora_config) + + +lora_te_moe_config = DistParallelConfig( + grouped_duplicated_map={ + ".linear_fc1.lora_A.*.weight": ".mlp.experts.linear_fc1.lora_A.*.weight", + ".linear_fc2.lora_B.*.weight": ".mlp.experts.linear_fc2.lora_B.*.weight", + }, + grouped_column_map={".linear_fc1.lora_B.*.weight": ".mlp.experts.linear_fc1.lora_B.*.weight"}, + grouped_row_map={".linear_fc2.lora_A.*.weight": ".mlp.experts.linear_fc2.lora_A.*.weight"}, ) te_moe_config = DistParallelConfig( grouped_column_map={".linear_fc1.weight": ".mlp.experts.linear_fc1.weight"}, grouped_row_map={".linear_fc2.weight": ".mlp.experts.linear_fc2.weight"}, -) +).merge_configs(lora_te_moe_config) mtp_config = DistParallelConfig( @@ -366,7 +403,9 @@ def get_pure_name(self, name: str): pure_name = remove_mca_weight_prefix(name) if self.use_te_grouped_moe: suffix_num = extract_suffix_number(pure_name) - if suffix_num is not None and pure_name[: -len(suffix_num)] in self.config.grouped_reverse_map: + if suffix_num is not None and self.name_match( + pure_name[: -len(suffix_num)], self.config.grouped_reverse_map + ): pure_name = pure_name[: -len(suffix_num)] if self.mca_config.transformer_impl == "local": if self.revert and pure_name in self.config.local_to_te_key_map: @@ -391,7 +430,7 @@ def name_relocate(self, name: str, vp_stage: int, moe_index: Optional[int] = Non if moe_index is not None: if self.revert: if self.mca_config.moe_grouped_gemm: - pure_name = self.config.grouped_reverse_map[pure_name] + pure_name = self.get_matched_name(pure_name, self.config.grouped_reverse_map) moe_index = self.num_layers_for_expert * self.expert_model_parallel_rank + moe_index else: if self.mca_config.moe_grouped_gemm: @@ -458,6 +497,23 @@ def handle_duplicated(self, name: str, weights: Union["Tensor", List["Tensor"]], name = self.name_relocate(name, vp_stage=vp_stage) return {name: weight} + def handle_grouped_duplicated(self, name: str, weights: Union["Tensor", List["Tensor"]]) -> Dict[str, "Tensor"]: + if self.revert: + weight = weights[0] + for w in weights[1:]: + if w.equal(weight): + continue + message = f"{name} weights are not equal diff sum: {torch.sum(torch.abs(w - weight))}" + if ASSERT_SP_CONSISTENCY: + raise ValueError(message) + else: + logger.warning(message) + break + 
else: + raise NotImplementedError() + moe_index = int(extract_suffix_number(name)) + return {self.name_relocate(name, moe_index=moe_index): weight} + def _convert_te_grouped_column(self, name: str, weights: "Tensor", vp_stage: int): if self.swiglu: weights = self._convert_swiglu(weights) @@ -577,7 +633,7 @@ def handle_grouped_row(self, name: str, weights: Union["Tensor", List["Tensor"]] return self._convert_te_grouped_row(name, weights, vp_stage=vp_stage) return self._convert_grouped_row(name, weights, vp_stage=vp_stage) - def name_match(self, pure_name: str, patterns: List[str]): + def name_match(self, pure_name: str, patterns: list[str] | dict[str, Any]): if pure_name in patterns: return True for pattern in patterns: @@ -585,14 +641,24 @@ def name_match(self, pure_name: str, patterns: List[str]): return True return False + def get_matched_name(self, name: str, weight_map: dict[str, Any]) -> Optional[str]: + if name in weight_map: + return weight_map[name] + for key in weight_map: + if fnmatch.fnmatch(name, key): + name_pattern = weight_map[key] + return name_pattern[:name_pattern.find(".lora")] + name[name.find(".lora"):] + def get_local_moe_index(self, name: str) -> Optional[Union[int, List[int]]]: pure_name = remove_mca_weight_prefix(name) if self.use_te_grouped_moe: suffix_num = extract_suffix_number(pure_name) - if suffix_num is not None and pure_name[: -len(suffix_num)] in self.config.grouped_reverse_map: + if suffix_num is not None and self.name_match( + pure_name[: -len(suffix_num)], self.config.grouped_reverse_map + ): return int(suffix_num) if self.mca_config.moe_grouped_gemm: - if pure_name in self.config.grouped_reverse_map: + if self.name_match(pure_name, self.config.grouped_reverse_map): return list(range(self.num_layers_for_expert)) return get_mca_moe_index(name) @@ -624,6 +690,8 @@ def dist_convert(self, name: str, weights: Union["Tensor", List["Tensor"]], vp_s pure_name = self.get_pure_name(name) if pure_name.endswith(".bias"): pure_name = pure_name.replace(".bias", ".weight") + if self.mca_config.moe_grouped_gemm and self.name_match(pure_name, self.config.grouped_duplicated_weights): + return self.handle_grouped_duplicated(name, weights) if self.mca_config.moe_grouped_gemm and self.name_match(pure_name, self.config.grouped_column_weights): return self.handle_grouped_column(name, weights, vp_stage=vp_stage) if self.mca_config.moe_grouped_gemm and self.name_match(pure_name, self.config.grouped_row_weights): @@ -650,6 +718,7 @@ def __call__(self, name: str, weights: Union["Tensor", List["Tensor"]], vp_stage @staticmethod def dist_converter_iter(mca_config: "McaModelConfig", **kwargs): + warnings.warn("dist_converter_iter is deprecated", DeprecationWarning) for tp_rank, pp_rank, ep_rank in product( range(mca_config.tensor_model_parallel_size), range(mca_config.pipeline_model_parallel_size), diff --git a/mcore_adapter/src/mcore_adapter/models/converter/post_converter.py b/mcore_adapter/src/mcore_adapter/models/converter/post_converter.py index 50274f293..61fc0cfcd 100644 --- a/mcore_adapter/src/mcore_adapter/models/converter/post_converter.py +++ b/mcore_adapter/src/mcore_adapter/models/converter/post_converter.py @@ -20,13 +20,15 @@ from ...checkpointing import get_checkpoint_name, save_config_and_state_dict from ...training_args import DistributingParallelArguments -from ...utils import get_logger +from ...utils import get_logger, is_peft_available from ..auto.config_auto import AutoConfig -from .dist_converter import DistConverter from .model_converter import 
ModelConverter from .template import get_template +if is_peft_available(): + from peft import LoraConfig, PeftConfig, get_peft_model + if TYPE_CHECKING: from ...training_args import DistributingParallelArguments from .template import Template @@ -68,13 +70,27 @@ def log(msg): def convert_checkpoint_to_hf( - model_name_or_path: str, save_directory: str, torch_dtype: Optional["torch.dtype"] = None, verbose: bool = True + model_name_or_path: str, + save_directory: str, + adapter_name_or_path: Optional[str] = None, + torch_dtype: Optional["torch.dtype"] = None, + verbose: bool = True, ): - mca_config = AutoConfig.from_pretrained(model_name_or_path) + if is_lora := adapter_name_or_path is not None: + if not is_peft_available(): + raise ImportError("PEFT is not installed. Please install it with `pip install peft`") + ckpt_path = adapter_name_or_path + peft_config = PeftConfig.from_pretrained(adapter_name_or_path) + else: + ckpt_path = model_name_or_path + mca_config = AutoConfig.from_pretrained(ckpt_path) if mca_config is None: raise ValueError("No mca config found in checkpoint") if mca_config.hf_model_type is None: raise ValueError("No hf model type found in mca config") + if is_lora: + setattr(mca_config, "lora_rank", peft_config.r) + template: "Template" = get_template(mca_config.hf_model_type) hf_config = template.convert_mca_to_hf_config(mca_config) template.set_mca_config_for_ops(mca_config) @@ -93,7 +109,7 @@ def convert_checkpoint_to_hf( # TODO: use loader and support low_mem for tp_rank in range(mca_config.tensor_model_parallel_size): ckpt_name = get_checkpoint_name( - model_name_or_path, + ckpt_path, tensor_rank=tp_rank, pipeline_rank=pp_rank, pipeline_parallel=mca_config.pipeline_model_parallel_size > 1, @@ -134,18 +150,43 @@ def convert_checkpoint_to_hf( else: model_class = _get_model_class(hf_config, model_class._model_mapping) - model = model_class.from_pretrained( - None, - config=hf_config, - state_dict=hf_state_dict, - torch_dtype=torch_dtype if torch_dtype is not None else mca_config.params_dtype, - trust_remote_code=True, - ) + if is_lora: + hf_config.save_pretrained(save_directory) + target_modules = set() + for name, _ in hf_state_dict.items(): + if ".lora_A." in name or ".lora_B." 
in name: + # TODO: support VLM lora + target_modules.add(name[:name.find(".lora")].split(".")[-1]) + target_modules = list(target_modules) + model = model_class.from_pretrained( + model_name_or_path, + config=hf_config, + torch_dtype=torch_dtype if torch_dtype is not None else mca_config.params_dtype, + trust_remote_code=True, + ) + lora_config = LoraConfig( + r=peft_config.r, + target_modules=target_modules, + lora_alpha=peft_config.lora_alpha, + lora_dropout=peft_config.lora_dropout, + use_rslora=peft_config.use_rslora, + modules_to_save=peft_config.modules_to_save, + ) + model = get_peft_model(model, lora_config) + model.base_model.model.load_state_dict(hf_state_dict, strict=False) + else: + model = model_class.from_pretrained( + None, + config=hf_config, + state_dict=hf_state_dict, + torch_dtype=torch_dtype if torch_dtype is not None else mca_config.params_dtype, + trust_remote_code=True, + ) model.save_pretrained(save_directory) mca_config.save_hf_auto_map_files(save_directory) - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(ckpt_path, trust_remote_code=True) try: - processor = AutoProcessor.from_pretrained(model_name_or_path, trust_remote_code=True) + processor = AutoProcessor.from_pretrained(ckpt_path, trust_remote_code=True) except Exception as e: logger.info(f"Processor was not found: {e}.") processor = tokenizer diff --git a/mcore_adapter/src/mcore_adapter/models/converter/template.py b/mcore_adapter/src/mcore_adapter/models/converter/template.py index 45a0ce353..901ba652c 100644 --- a/mcore_adapter/src/mcore_adapter/models/converter/template.py +++ b/mcore_adapter/src/mcore_adapter/models/converter/template.py @@ -129,6 +129,21 @@ def _mca_to_hf(self, weights): return weights +@dataclass +class CopyConverOp(ConverOp): + def __post_init__(self): + super().__post_init__() + assert (len(self.hf_names) == 1) != (len(self.mca_names) == 1), ( + f"CopyConverOp only supports one name as target {self.hf_names} {self.mca_names}" + ) + + def _hf_to_mca(self, weights): + return weights * len(self.mca_names) + + def _mca_to_hf(self, weights): + return weights * len(self.hf_names) + + @dataclass class ConcatConverOp(ConverOp): dim: int = 0 @@ -173,13 +188,18 @@ def _mca_to_hf(self, weights): return StackedTensors(tensors=weights, dim=self.dim) +@dataclass class QKVConverOp(ConverOp): + hidden_size: Optional[int] = None + def __post_init__(self): super().__post_init__() assert len(self.hf_names) == 3, f"QKVConverOp only support three hf_names {self.hf_names}" assert len(self.mca_names) == 1, f"QKVConverOp only support one mca_name {self.mca_names}" def _hf_to_mca(self, weights): + if self.hidden_size is None: + self.hidden_size = self.mca_config.hidden_size q_weight, k_weight, v_weight = weights nh = self.mca_config.num_attention_heads ng = self.mca_config.num_query_groups @@ -192,22 +212,25 @@ def _hf_to_mca(self, weights): v_weight.reshape((ng, dim, -1)), ], dim=1, - ).reshape((-1, self.mca_config.hidden_size)) + ).reshape((-1, self.hidden_size)) return mca_qkv_weight def _mca_to_hf(self, weights): + if self.hidden_size is None: + self.hidden_size = self.mca_config.hidden_size qkv_weight = weights[0] ng = self.mca_config.num_query_groups nh = self.mca_config.num_attention_heads dim = self.mca_config.kv_channels qkv_weight = qkv_weight.reshape((ng, dim * (nh // ng + 2), -1)) qkv_weights = torch.split(qkv_weight, [dim * nh // ng, dim, dim], dim=1) - q_weight = qkv_weights[0].reshape((-1, 
self.mca_config.hidden_size)) - k_weight = qkv_weights[1].reshape((-1, self.mca_config.hidden_size)) - v_weight = qkv_weights[2].reshape((-1, self.mca_config.hidden_size)) + q_weight = qkv_weights[0].reshape((-1, self.hidden_size)) + k_weight = qkv_weights[1].reshape((-1, self.hidden_size)) + v_weight = qkv_weights[2].reshape((-1, self.hidden_size)) return [q_weight, k_weight, v_weight] +@dataclass class QKVBiasConverOp(ConverOp): def __post_init__(self): super().__post_init__() @@ -348,7 +371,10 @@ def add_mca_weight(self, name, weight): self.prefix_name_to_weight[weight_prefix] = {} self.prefix_name_to_weight[weight_prefix][original_name] = weight prefix_weights = self.prefix_name_to_weight[weight_prefix] - op = self.get_conver_op(original_name, self.mca_name_to_converter) + if ".lora_A." in original_name or ".lora_B." in original_name: + op = self.get_lora_conver_op(original_name, self.mca_name_to_converter) + else: + op = self.get_conver_op(original_name, self.mca_name_to_converter) name_to_weight = { name: prefix_weights.pop(name) for name in list(prefix_weights.keys()) @@ -371,6 +397,31 @@ def get_conver_op(self, name, pattern_to_conver_ops: Dict[str, ConverOp]): return pattern_to_conver_ops[pattern] raise ValueError(f"can not find conver op for {name} in {pattern_to_conver_ops}") + def get_lora_conver_op(self, name, pattern_to_conver_ops: Dict[str, ConverOp]): + lora_name = name[name.find(".lora"):] + name = name[:name.find(".lora")] + ".weight" + op = self.get_conver_op(name, pattern_to_conver_ops) + if isinstance(op, RenameConverOp): + op_class = RenameConverOp + kwargs = {} + elif "lora_A" in lora_name: + op_class = CopyConverOp + kwargs = {} + elif isinstance(op, StackConverOp): + op_class = StackConverOp + kwargs = {"dim": op.dim} + elif isinstance(op, QKVConverOp): + op_class = QKVConverOp + kwargs = {"hidden_size": op.mca_config.lora_rank} + else: + raise ValueError(f"can not find lora conver op for {name} in {pattern_to_conver_ops}") + return op_class( + hf_names=[hf_name.replace(".weight", lora_name) for hf_name in op.hf_names], + mca_names=[mca_name.replace(".weight", lora_name) for mca_name in op.mca_names], + mca_config=op.mca_config, + **kwargs, + ) + def hf_name_to_mca_names(self, hf_name) -> Optional[List[str]]: weight_prefix = get_weight_prefix(hf_name, self.hf_layer_prefix, moe_prefix=self.hf_moe_prefix) original_name = remove_weight_prefix(hf_name, self.hf_layer_prefix, moe_prefix=self.hf_moe_prefix) diff --git a/mcore_adapter/src/mcore_adapter/models/model_config.py b/mcore_adapter/src/mcore_adapter/models/model_config.py index eb43f4c35..ba11580f8 100644 --- a/mcore_adapter/src/mcore_adapter/models/model_config.py +++ b/mcore_adapter/src/mcore_adapter/models/model_config.py @@ -1,3 +1,4 @@ +import copy import dataclasses import enum import hashlib @@ -14,7 +15,7 @@ from transformers import AutoConfig from transformers.configuration_utils import CONFIG_NAME as HF_CONFIG_NAME -from ..constants import MCA_CONFIG_NAME, HUGGINGFACE_AUTOMAP_CACHE +from ..constants import HUGGINGFACE_AUTOMAP_CACHE, MCA_CONFIG_NAME from ..initialize import initialize_megatron from ..training_args import DistributingParallelArguments, TrainingArguments from ..utils import get_logger @@ -47,7 +48,18 @@ def post_init(self): self.__post_init__() def to_dict(self): - return dataclasses.asdict(self) + output = {} + for k, v in self.__dict__.items(): + if callable(v): + output[k] = None + elif isinstance(v, list) and callable(v[0]): + output[k] = None + elif isinstance(v, 
PipelineParallelLayerLayout): + output[k] = str(v) + else: + output[k] = copy.deepcopy(v) + + return output def to_json_string(self): save_dict = {} diff --git a/mcore_adapter/src/mcore_adapter/models/model_factory.py b/mcore_adapter/src/mcore_adapter/models/model_factory.py index 67de93cfb..77975b9f9 100644 --- a/mcore_adapter/src/mcore_adapter/models/model_factory.py +++ b/mcore_adapter/src/mcore_adapter/models/model_factory.py @@ -15,12 +15,16 @@ from megatron.core.transformer.module import MegatronModule from ..checkpointing import load_state_dict_from_checkpoint, save_config_and_state_dict -from ..utils import get_logger +from ..platforms import current_platform +from ..utils import get_logger, is_peft_available from .converter.convert_utils import MAX_SHARD_SIZE from .converter.model_converter import ModelConverter from .model_config import McaModelConfig from .model_utils import ModuleUtilsMixin, RMSNorm, exists_hf_config, exists_mca_config, get_thd_data_on_this_cp_rank -from ..platforms import current_platform + + +if is_peft_available(): + from peft import PeftModel if TYPE_CHECKING: @@ -43,6 +47,10 @@ def __init__(self, cls, config: "McaModelConfig", *args, **kwargs): def save_pretrained(self, save_directory: str): if len(self.models) == 1: + if is_peft_available() and isinstance(self.models[0], PeftModel): + for _, peft_config in self.models[0].peft_config.items(): + peft_config.save_pretrained(save_directory) + return self.models[0].base_model.model.save_pretrained(save_directory) return self.models[0].save_pretrained(save_directory) state_dict = {f"model{i}": model.state_dict_for_save_checkpoint() for i, model in enumerate(self.models)} return self.models[0].save_pretrained(save_directory, state_dict=state_dict) @@ -51,6 +59,8 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = Tr if len(self.models) == 1: if "model" in state_dict: state_dict = state_dict["model"] + if is_peft_available() and isinstance(self.models[0], PeftModel): + return self.models[0].base_model.model.load_state_dict(state_dict, strict=False) return self.models[0].load_state_dict(state_dict, strict=strict) all_missing_keys, all_unexpected_keys = [], [] for i, model in enumerate(self.models): @@ -203,7 +213,20 @@ def from_pretrained( def save_pretrained(self, save_directory: str, state_dict=None): os.makedirs(save_directory, exist_ok=True) - state_dict = state_dict if state_dict is not None else {"model": self.state_dict_for_save_checkpoint()} + if state_dict is None: + new_state_dict = {} + state_dict_model = self.state_dict_for_save_checkpoint() + for n, p in self.named_parameters(): + if not p.requires_grad: + continue + if n in state_dict_model: + new_state_dict[n] = state_dict_model[n] + key = n.replace('.weight', '._extra_state') + if key.endswith('._extra_state0'): + key = key.replace('._extra_state0', '._extra_state') + if key in state_dict_model: + new_state_dict[key] = state_dict_model[key] + state_dict = {"model": new_state_dict} save_config_and_state_dict(save_directory, self.config, state_dict) def get_batch_on_this_cp_rank(self, batch: Dict[str, "torch.Tensor"], dim3_keys: List[str] = ["attention_mask"]): @@ -244,6 +267,18 @@ def get_batch_on_this_cp_rank(self, batch: Dict[str, "torch.Tensor"], dim3_keys: return batch + def enable_input_require_grads(self): + """ + Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping + the model weights fixed. 
+ """ + + def make_inputs_require_grads(module, input, output): + output.requires_grad_(True) + + if hasattr(self, "embedding"): + self._require_grads_hook = self.embedding.register_forward_hook(make_inputs_require_grads) + class McaGPTModel(GPTModel, PretrainedModel): main_input_name: str = "input_ids" diff --git a/mcore_adapter/src/mcore_adapter/models/model_utils.py b/mcore_adapter/src/mcore_adapter/models/model_utils.py index 70d07e396..c7c83817e 100644 --- a/mcore_adapter/src/mcore_adapter/models/model_utils.py +++ b/mcore_adapter/src/mcore_adapter/models/model_utils.py @@ -8,8 +8,9 @@ from megatron.core.transformer.enums import AttnBackend from ..constants import MCA_CONFIG_NAME -from ..utils import get_logger from ..platforms import current_platform +from ..utils import get_logger + if TYPE_CHECKING: from megatron.core.transformer import TransformerConfig diff --git a/mcore_adapter/src/mcore_adapter/utils.py b/mcore_adapter/src/mcore_adapter/utils.py index 0964506c3..c56fdb830 100644 --- a/mcore_adapter/src/mcore_adapter/utils.py +++ b/mcore_adapter/src/mcore_adapter/utils.py @@ -71,5 +71,5 @@ def _is_package_available(name: str) -> bool: return importlib.util.find_spec(name) is not None -def is_fla_available() -> bool: - return _is_package_available("fla") +def is_peft_available() -> bool: + return _is_package_available("peft") From 174787154d51d6bd283b7aff58faedda4f1edf14 Mon Sep 17 00:00:00 2001 From: wzy496492 Date: Wed, 5 Nov 2025 14:11:28 +0800 Subject: [PATCH 23/58] (fix): fix bugs in data fetching for face embeddings. --- roll/pipeline/diffusion/modules/wan_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roll/pipeline/diffusion/modules/wan_module.py b/roll/pipeline/diffusion/modules/wan_module.py index 7abd8a3e6..a80198a01 100644 --- a/roll/pipeline/diffusion/modules/wan_module.py +++ b/roll/pipeline/diffusion/modules/wan_module.py @@ -233,7 +233,7 @@ def forward(self, data, inputs=None): if inputs is None: inputs = self.forward_preprocess(data) - face_embeddings = data['face_embeddings'].unsqueeze(0).to(device=self.pipe.device, dtype=self.pipe.torch_dtype) + face_embeddings = inputs['face_embeddings'].to(device=self.pipe.device, dtype=self.pipe.torch_dtype) # step1: forward latents + vae decode video_decoded, kl_loss = self.pipe.training_loss(**inputs) From d353266402fc3ad44338f6880f31c08e909cf9fb Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Tue, 2 Dec 2025 16:15:17 +0800 Subject: [PATCH 24/58] (feat): add agentic chunk. 
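
The core of this change is the new `agentic_reinforce` advantage estimator: `compute_agentic_reinforce_return` splits each sample's response mask into contiguous segments (for example, one per model turn), takes the reward at the last token of each segment, and accumulates it backwards with `gamma`, so every token inside a segment shares that segment's discounted return. A minimal standalone sketch of that behavior, with made-up rewards and mask purely for illustration:

```python
import torch

def toy_segment_returns(rewards: torch.Tensor, mask: torch.Tensor, gamma: float) -> torch.Tensor:
    """Toy mirror of the segment logic in compute_agentic_reinforce_return (single sample)."""
    returns = torch.zeros_like(rewards)
    # segment boundaries from the 0/1 mask: starts at 0->1 edges, ends at 1->0 edges
    diff = torch.diff(mask.float(), prepend=torch.tensor([0.0]))
    starts = torch.where(diff == 1)[0].tolist()
    ends = torch.where(diff == -1)[0].tolist()
    if len(mask) > 0 and mask[-1] == 1:
        ends.append(len(mask))
    cumulative = 0.0
    # walk segments back to front so later segments get discounted into earlier ones
    for start, end in zip(reversed(starts), reversed(ends)):
        cumulative = rewards[end - 1].item() + gamma * cumulative
        returns[start:end] = cumulative
    return returns

# two "steps": tokens 0-2 and tokens 5-7; only the last token of each step carries a reward
mask = torch.tensor([1, 1, 1, 0, 0, 1, 1, 1])
rewards = torch.tensor([0.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 1.0])
print(toy_segment_returns(rewards, mask, gamma=0.9))
# tokens 0-2 -> 0.5 + 0.9 * 1.0 = 1.4, tokens 5-7 -> 1.0, masked-out tokens stay 0
```

Selecting `adv_estimator: agentic_reinforce` routes `agentic_compute_advantage` into this path; when no mask is provided the whole response is treated as a single segment. The `ratio_type: segment` option added alongside is declared but not implemented yet and still raises if selected.
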
--- roll/configs/base_config.py | 2 +- roll/pipeline/agentic/agentic_actor_worker.py | 7 +- roll/pipeline/agentic/agentic_config.py | 1 + roll/pipeline/agentic/agentic_pipeline.py | 4 +- roll/pipeline/agentic/utils.py | 118 ++++++++++++++++++ 5 files changed, 128 insertions(+), 4 deletions(-) diff --git a/roll/configs/base_config.py b/roll/configs/base_config.py index 350f69cc5..8d9871e35 100644 --- a/roll/configs/base_config.py +++ b/roll/configs/base_config.py @@ -355,7 +355,7 @@ class PPOConfig(BaseConfig): whiten_rewards: bool = field(default=False, metadata={"help": "Whiten the rewards before compute advantages."}) whiten_advantages: bool = field(default=False, metadata={"help": "Whiten the advantage."}) advantage_clip: float = field(default=None, metadata={"help": "advantage_clip value"}) - adv_estimator: Literal["gae", "reinforce", "grpo", "gigpo", "step_reinforce"] = field( + adv_estimator: Literal["gae", "reinforce", "grpo", "gigpo", "step_reinforce", "agentic_reinforce"] = field( default="gae", metadata={"help": "advantage estimator: gae (GAE)."} ) norm_mean_type: Literal["batch", "group", "running", None] = field( diff --git a/roll/pipeline/agentic/agentic_actor_worker.py b/roll/pipeline/agentic/agentic_actor_worker.py index c1e273e3e..a4298a98a 100644 --- a/roll/pipeline/agentic/agentic_actor_worker.py +++ b/roll/pipeline/agentic/agentic_actor_worker.py @@ -4,6 +4,7 @@ from roll.distributed.scheduler.protocol import DataProto from roll.pipeline.base_worker import ActorWorker as BaseActorWorker from roll.utils.functionals import masked_mean, agg_loss, compute_approx_kl +from roll.pipeline.agentic.utils import compute_segment_masked_mean class ActorWorker(BaseActorWorker): @@ -25,7 +26,11 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"] ) - ratio = (log_probs - old_log_probs).exp() + if self.pipeline_config.ratio_type == "segment": + raise NotImplemented(f"ratio_type: {self.pipeline_config.ratio_type} not implemented") + else: + ratio = (log_probs - old_log_probs).exp() + train_infer_ratio = (log_probs - infer_log_probs).exp() train_infer_diff = log_probs.exp() - infer_log_probs.exp() diff --git a/roll/pipeline/agentic/agentic_config.py b/roll/pipeline/agentic/agentic_config.py index f421d74cf..77f2db5c7 100644 --- a/roll/pipeline/agentic/agentic_config.py +++ b/roll/pipeline/agentic/agentic_config.py @@ -97,6 +97,7 @@ class AgenticConfig(PPOConfig): episode_reward_weight: float = field(default=1.0, metadata={"help": "Episode reward weight, used in GiGPO."}) step_reward_weight: float = field(default=1.0, metadata={"help": "Step reward weight, used in GiGPO."}) step_reward_gamma: float = field(default=0.95, metadata={"help": "Gamma parameter for step reward calculation"}) + ratio_type: Literal["token", "segment"] = field(default="token", metadata={"help": "Ratio type: token or segment"}) def __post_init__(self): super().__post_init__() diff --git a/roll/pipeline/agentic/agentic_pipeline.py b/roll/pipeline/agentic/agentic_pipeline.py index b85024f3b..7ea27dee7 100644 --- a/roll/pipeline/agentic/agentic_pipeline.py +++ b/roll/pipeline/agentic/agentic_pipeline.py @@ -17,7 +17,7 @@ from roll.models.model_providers import default_tokenizer_provider from roll.pipeline.agentic.agentic_config import AgenticConfig, EnvManagerConfig from roll.pipeline.agentic.utils import (dump_rollout_render, compute_discounted_returns, - compute_response_level_rewards, 
dump_rollout_trajectories, get_agentic_response_level_mask) + compute_response_level_rewards, dump_rollout_trajectories, get_agentic_response_level_mask, agentic_compute_advantage) from roll.pipeline.base_pipeline import BasePipeline from roll.utils.constants import RAY_NAMESPACE from roll.utils.functionals import ( @@ -232,7 +232,7 @@ def run(self): with Timer(name="compute_advantage", logger=None) as timer: # Is the advantage calculated globally across the batch, or within each group? - batch = compute_advantage( + batch = agentic_compute_advantage( data=batch, gamma=self.pipeline_config.gamma, lambd=self.pipeline_config.lambd, diff --git a/roll/pipeline/agentic/utils.py b/roll/pipeline/agentic/utils.py index d7b86f4ee..13a383d32 100644 --- a/roll/pipeline/agentic/utils.py +++ b/roll/pipeline/agentic/utils.py @@ -18,6 +18,7 @@ from roll.pipeline.agentic.agentic_config import AgenticConfig, RewardNormalizationConfig from roll.pipeline.rlvr.utils import DUMPING_FUNC from roll.utils.logging import get_logger +from roll.utils.functionals import masked_whiten, compute_gae_advantage_return, compute_clip_fraction, compute_reinforce_return logger = get_logger() @@ -263,3 +264,120 @@ def dump_rollout_trajectories(path, global_step, data: DataProto): p = multiprocessing.Process(target=func, args=(path, write_data, columns_config), daemon=False) p.start() +def compute_agentic_reinforce_return(token_level_rewards: torch.Tensor, gamma: torch.Tensor, lambd: torch.Tensor, mask: Optional[torch.Tensor] = None): + """ + 计算 REINFORCE 的 return,支持按 mask 分段 discount 衰减。 + 每段内所有位置获得相同的折扣累积值(从该段最后位置开始累积)。 + + Args: + token_level_rewards: [batch_size, seq_len] token 级别的奖励 + gamma: discount factor + lambd: lambda 参数(当前未使用,保留以兼容接口) + mask: [batch_size, seq_len] mask,1表示有效位置,0表示无效位置。如果为None,则对所有位置计算 + + Returns: + advantages: [batch_size, seq_len] advantages + returns: [batch_size, seq_len] returns + """ + with torch.no_grad(): + batch_size, gen_len = token_level_rewards.shape + device = token_level_rewards.device + returns = torch.zeros_like(token_level_rewards, dtype=torch.float32) + + # 如果没有提供 mask,则对所有位置计算(向后兼容) + if mask is None: + mask = torch.ones_like(token_level_rewards) + + # 确保 gamma 是标量 + gamma_val = gamma.item() if torch.is_tensor(gamma) else gamma + + # 对每个样本分别处理 + for b in range(batch_size): + sample_mask = mask[b] # [seq_len] + sample_rewards = token_level_rewards[b] # [seq_len] + + # 找到所有连续的1的段 + # 使用 diff 找到边界:1->0 和 0->1 的位置 + diff = torch.diff(sample_mask.float(), prepend=torch.tensor([0.0], device=device)) + + # 找到段的开始位置(0->1,diff==1) + segment_starts = torch.where(diff == 1)[0] + + # 找到段的结束位置(1->0,diff==-1) + segment_ends = torch.where(diff == -1)[0] + + # 如果最后一个位置是1,需要添加结束位置 + if len(sample_mask) > 0 and sample_mask[-1] == 1: + segment_ends = torch.cat([segment_ends, torch.tensor([gen_len], device=device)]) + + # 计算该段从最后位置开始的累积折扣奖励 + cumulative_return = 0.0 + # 对每段分别计算 discounted return + for start, end in zip(segment_starts.flip(-1), segment_ends.flip(-1)): + start_idx = start.item() + end_idx = end.item() + segment_len = end_idx - start_idx + + cumulative_return = sample_rewards[end_idx-1].item() + gamma_val * cumulative_return + + # 该段内所有位置都设置为这个累积值 + returns[b, start_idx:end_idx] = cumulative_return + + advantages = returns + + return advantages, returns + +@torch.no_grad() +def agentic_compute_advantage( + data: "DataProto", + gamma, + lambd, + adv_estimator, + advantage_clip=None, + whiten_advantages=False, + whiten_rewards=False, + response_mask=None, +): + if response_mask is None: + 
response_mask = data.batch["response_mask"][:, 1:] + if response_mask.sum() == 0: + whiten_rewards = False + whiten_advantages = False + logger.info("Warning: domain final_response_mask.sum() == 0! All masked_whiten will be skipped.") + + token_level_rewards = data.batch["token_level_rewards"].float() + if whiten_rewards: + token_level_rewards = masked_whiten(values=token_level_rewards, mask=response_mask) + token_level_rewards = token_level_rewards * response_mask + data.batch["token_level_rewards"] = token_level_rewards + if adv_estimator == "gae": + values = data.batch["values"].float() + data.batch["values"] = values * response_mask + advantages, returns = compute_gae_advantage_return( + token_level_rewards=token_level_rewards, values=values, gamma=gamma, lambd=lambd + ) + elif adv_estimator in ["reinforce", "grpo", "gigpo", "step_reinforce"]: + advantages, returns = compute_reinforce_return( + token_level_rewards=token_level_rewards, gamma=gamma, lambd=lambd + ) + elif adv_estimator in ["agentic_reinforce"]: + advantages, returns = compute_agentic_reinforce_return( + token_level_rewards=token_level_rewards, gamma=gamma, lambd=lambd, mask=response_mask + ) + else: + raise NotImplementedError + + data.batch["raw_advantages"] = advantages + if whiten_advantages: + # TODO whiten过程中是否要考虑response的长度? + advantages = masked_whiten(values=advantages, mask=response_mask) + advantages = advantages * response_mask + + if advantage_clip is not None: + adv_clip_frac = compute_clip_fraction(values=advantages, clip_min=-advantage_clip, clip_max=advantage_clip) + data.meta_info["metrics"] = {"critic/advantage_clip_frac": adv_clip_frac} + advantages = torch.clamp(advantages, min=-advantage_clip, max=advantage_clip) + + data.batch["advantages"] = advantages + data.batch["returns"] = returns + return data \ No newline at end of file From 2a73b0ac174f38880b7a1901cec89a994ae8859b Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Fri, 5 Dec 2025 17:13:52 +0800 Subject: [PATCH 25/58] (feat): add sglang 0.4.6.post5. --- roll/third_party/sglang/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roll/third_party/sglang/__init__.py b/roll/third_party/sglang/__init__.py index d35671f7c..19c13ee31 100644 --- a/roll/third_party/sglang/__init__.py +++ b/roll/third_party/sglang/__init__.py @@ -4,7 +4,7 @@ if sgl.__version__ == '0.4.6.post4': from roll.third_party.sglang import v046post4_patch patch = v046post4_patch -elif sgl.__version__ == '0.4.6.post1': +elif sgl.__version__ == '0.4.6.post1' or sgl.__version__ == '0.4.6.post5': from roll.third_party.sglang import v046post4_patch patch = v046post4_patch elif sgl.__version__ == '0.4.10.post2': From e9ba1319d3ba7f8581e12db299038ce0b00993de Mon Sep 17 00:00:00 2001 From: "xuehuanran.xhr" Date: Fri, 7 Nov 2025 16:54:37 +0800 Subject: [PATCH 26/58] (feat): support offload nccl to save gpu memory. 
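
Setting `offload_nccl: true` on a worker makes `Worker.__init__` call `monkey_patch_torch_dist()`, so every NCCL group later created through `dist.new_group` is wrapped in a `ReloadableProcessGroup`; `state_offload_manger` then destroys those groups after offloading states and rebuilds them before the next onload, reclaiming NCCL buffers in between. A rough standalone sketch of that lifecycle (the two-rank group is a made-up example and assumes `torch.distributed` has already been initialized with the NCCL backend):

```python
import torch.distributed as dist

from roll.utils.offload_nccl import (
    destroy_process_groups,
    monkey_patch_torch_dist,
    reload_process_groups,
)

monkey_patch_torch_dist()                # patch before any dist.new_group() call
tp_group = dist.new_group(ranks=[0, 1])  # comes back as a ReloadableProcessGroup

# ... run collectives on tp_group as usual during the training phase ...

destroy_process_groups()                 # drop the underlying NCCL communicators while idle
# ... another phase (e.g. generation) can now use the freed GPU memory ...
reload_process_groups()                  # recreate the same groups (same ranks) before reuse
```

The cost of the destroy/reload pair is reported through the new `time/.../offload_nccl`, `time/.../reload_nccl` and matching `memory/...` metrics added in `context_managers.py`.
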
--- roll/configs/worker_config.py | 4 + roll/distributed/executor/worker.py | 3 + roll/distributed/strategy/strategy.py | 1 + roll/pipeline/base_worker.py | 7 +- roll/utils/context_managers.py | 20 +- roll/utils/offload_nccl.py | 268 ++++++++++++++++++++++++++ 6 files changed, 300 insertions(+), 3 deletions(-) create mode 100644 roll/utils/offload_nccl.py diff --git a/roll/configs/worker_config.py b/roll/configs/worker_config.py index 02be29a78..eae29bad0 100644 --- a/roll/configs/worker_config.py +++ b/roll/configs/worker_config.py @@ -144,6 +144,10 @@ class WorkerConfig: metadata={"help": "The value to round up to when truncating the sequence length." "Note: This config must be set when using dynamic batching."} ) + offload_nccl: bool = field( + default=False, + metadata={"help": "Whether offload nccl buffer to save gpu memory."} + ) def __post_init__(self): diff --git a/roll/distributed/executor/worker.py b/roll/distributed/executor/worker.py index 072411b62..4f11f0cb5 100644 --- a/roll/distributed/executor/worker.py +++ b/roll/distributed/executor/worker.py @@ -17,6 +17,7 @@ from roll.utils.logging import get_logger from roll.utils.network_utils import collect_free_port, get_node_ip from roll.utils.offload_states import OffloadStateType +from roll.utils.offload_nccl import monkey_patch_torch_dist from roll.platforms import current_platform @@ -42,6 +43,8 @@ def is_pipeline_last_stage(self): class Worker: def __init__(self, worker_config: WorkerConfig): + if worker_config.offload_nccl: + monkey_patch_torch_dist() self.worker_config = worker_config self.pipeline_config = None self.worker_name = os.environ.get("WORKER_NAME", None) diff --git a/roll/distributed/strategy/strategy.py b/roll/distributed/strategy/strategy.py index bbc92a230..7c49c5c80 100644 --- a/roll/distributed/strategy/strategy.py +++ b/roll/distributed/strategy/strategy.py @@ -27,6 +27,7 @@ def __init__(self, worker: "Worker"): self.worker_config = self.worker.worker_config self.thread_executor: futures.ThreadPoolExecutor = futures.ThreadPoolExecutor(max_workers=5) self.model_update_comm_plan = {} + self.offload_nccl = self.worker_config.offload_nccl def initialize(self, *args, **kwargs): raise NotImplementedError diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py index 0987cfefa..8a4aac5bb 100644 --- a/roll/pipeline/base_worker.py +++ b/roll/pipeline/base_worker.py @@ -26,6 +26,7 @@ GenerateRequestType, agg_loss, ) +from roll.utils.offload_nccl import reload_process_groups from roll.utils.offload_states import OffloadStateType from roll.utils.dynamic_batching import make_mini_batch_iter_for_dynamic_batching from roll.platforms import current_platform @@ -218,8 +219,7 @@ def stop_server(self, data: DataProto = None): def compute_log_probs(self, data: DataProto): """ return DataProto.from_dict(tensors={'log_probs': output}) - """ - data = self.strategy.get_data_input(data) + """ global_step = data.meta_info.get("global_step", 0) is_offload_states = data.meta_info.get("is_offload_states", True) metrics = {} @@ -230,6 +230,7 @@ def compute_log_probs(self, data: DataProto): is_offload_states=is_offload_states, load_kwargs={"include": [OffloadStateType.model_params]}, ): + data = self.strategy.get_data_input(data) data = data.to(current_platform.device_type) data.meta_info["micro_batch_size"] = self.worker_config.infer_batch_size with torch.no_grad(): @@ -334,6 +335,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): @register(dispatch_mode=Dispatch.ONE_TO_ALL) def do_checkpoint(self, 
global_step): + if self.worker_config.offload_nccl: + reload_process_groups() with Timer("do_checkpoint") as total_timer: ckpt_id = f"checkpoint-{global_step}" diff --git a/roll/utils/context_managers.py b/roll/utils/context_managers.py index 7014e46e2..7fa574f17 100644 --- a/roll/utils/context_managers.py +++ b/roll/utils/context_managers.py @@ -14,6 +14,7 @@ from roll.platforms import current_platform from roll.utils.offload_states import OffloadStateType +from roll.utils.offload_nccl import reload_process_groups, destroy_process_groups from roll.utils.logging import get_logger, is_roll_debug_mode @@ -31,10 +32,11 @@ def log_gpu_memory_usage(head: str, logger: logging.Logger = None, rank: int = 0 memory_allocated = current_platform.memory_allocated() / 1024**3 memory_reserved = current_platform.memory_reserved() / 1024**3 memory_reserved_max = current_platform.max_memory_reserved() / 1024**3 + memory_device_used = current_platform.device_memory_used() / 1024**3 rss = cpu_memory_info().rss / 1024**3 message = ( f"{head}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}, " - f"memory max reserved (GB): {memory_reserved_max}, rss (GB): {rss}" + f"memory max reserved (GB): {memory_reserved_max}, rss (GB): {rss} memory device used (GB): {memory_device_used}" ) logger.info(msg=message) @@ -80,6 +82,14 @@ def local_profiler(): yield +@contextmanager +def gpu_memory_offload_profiler(metrics, metric_infix, stage): + memory_start_offload = current_platform.device_memory_used() / 1024**3 + yield + memory_end_offload = current_platform.device_memory_used() / 1024**3 + metrics[f"memory/{metric_infix}/{stage}"] = abs(memory_end_offload - memory_start_offload) + + def get_load_exclude_kwargs(load_kwargs): assert load_kwargs.get("include", None) is not None exclude_kwargs = copy.deepcopy(load_kwargs) @@ -158,6 +168,10 @@ def state_offload_manger(strategy, metrics: Dict, metric_infix: str, is_offload_ strategy.load_states(**load_kwargs) if load_kwargs.get("include", None) is not None: strategy.offload_states(**get_load_exclude_kwargs(load_kwargs)) + if strategy.offload_nccl: + with Timer(f"{metric_infix}_reload") as reload_pg_timer, gpu_memory_offload_profiler(metrics, metric_infix, "reload_nccl"): + reload_process_groups() + metrics[f"time/{metric_infix}/reload_nccl"] = reload_pg_timer.last log_gpu_memory_usage(head=f"{metric_infix}_start_onload", logger=logger, rank=None) metrics.update(_get_gpu_memory_metrics(metric_infix, "start/onload")) @@ -173,6 +187,10 @@ def state_offload_manger(strategy, metrics: Dict, metric_infix: str, is_offload_ if is_offload_states: current_platform.clear_cublas_workspaces() strategy.offload_states() + if strategy.offload_nccl: + with Timer(f"{metric_infix}_destroy") as destroy_pg_timer, gpu_memory_offload_profiler(metrics, metric_infix, "offload_nccl"): + destroy_process_groups() + metrics[f"time/{metric_infix}/offload_nccl"] = destroy_pg_timer.last log_gpu_memory_usage(head=f"{metric_infix}_end_offload", logger=logger, rank=None) metrics.update(_get_gpu_memory_metrics(metric_infix, "end/offload")) diff --git a/roll/utils/offload_nccl.py b/roll/utils/offload_nccl.py new file mode 100644 index 000000000..143e8a139 --- /dev/null +++ b/roll/utils/offload_nccl.py @@ -0,0 +1,268 @@ +""" +ref: https://github.com/THUDM/slime/blob/main/docker/patch/latest/megatron.patch +""" +import os +import torch +import torch.distributed as dist + +from roll.utils.logging import get_logger + + +logger = get_logger() + +old_new_group_dict = {} + + +def 
monkey_patch_torch_dist(): + pid = os.getpid() + if pid in old_new_group_dict: + assert dist.old_new_group == old_new_group_dict[pid] + return + + logger.info("Applying monkey patch to torch.distributed") + + old_new_group = dist.new_group + old_new_group_dict[pid] = old_new_group + setattr(dist, "old_new_group", old_new_group) + + def new_group(*args, **kwargs): + group = old_new_group(*args, **kwargs) + # skip non-nccl group. + if (len(args) >= 3 and args[2] == "gloo") or ("backend" in kwargs and kwargs["backend"] == "gloo"): + return group + + # Get ranks from arguments + if len(args) >= 1 and args[0] is not None: + ranks = args[0] + elif "ranks" in kwargs and kwargs["ranks"] is not None: + ranks = kwargs["ranks"] + else: + # If no ranks specified, use all ranks in world + ranks = list(range(dist.get_world_size())) + + if len(ranks) == 1: + return group + + group = ReloadableProcessGroup(group, ranks) + return group + + dist.new_group = new_group + + def get_new_function(func): + def new_function(*args, **kwargs): + args = (arg.group if isinstance(arg, ReloadableProcessGroup) else arg for arg in args) + kwargs = {k: (v.group if isinstance(v, ReloadableProcessGroup) else v) for k, v in kwargs.items()} + return func(*args, **kwargs) + + return new_function + + dist.get_rank = get_new_function(dist.get_rank) + dist.get_world_size = get_new_function(dist.get_world_size) + dist.get_backend = get_new_function(dist.get_backend) + dist.get_global_rank = get_new_function(dist.get_global_rank) + dist.get_group_rank = get_new_function(dist.get_group_rank) + dist.get_process_group_ranks = get_new_function(dist.get_process_group_ranks) + + dist.all_reduce = get_new_function(dist.all_reduce) + dist.all_gather = get_new_function(dist.all_gather) + dist.all_gather_into_tensor = get_new_function(dist.all_gather_into_tensor) + dist.all_gather_object = get_new_function(dist.all_gather_object) + dist.all_to_all = get_new_function(dist.all_to_all) + dist.all_to_all_single = get_new_function(dist.all_to_all_single) + dist.broadcast = get_new_function(dist.broadcast) + dist.broadcast_object_list = get_new_function(dist.broadcast_object_list) + dist.reduce = get_new_function(dist.reduce) + dist.reduce_scatter = get_new_function(dist.reduce_scatter) + dist.reduce_scatter_tensor = get_new_function(dist.reduce_scatter_tensor) + dist.scatter = get_new_function(dist.scatter) + dist.gather = get_new_function(dist.gather) + dist.barrier = get_new_function(dist.barrier) + dist.send = get_new_function(dist.send) + dist.recv = get_new_function(dist.recv) + dist._coalescing_manager = get_new_function(dist._coalescing_manager) + + # p2p + old_isend = dist.isend + old_irecv = dist.irecv + + dist.isend = get_new_function(dist.isend) + dist.irecv = get_new_function(dist.irecv) + + def get_new_p2pop_function(func): + def new_function(*args, **kwargs): + def convert(arg): + if isinstance(arg, ReloadableProcessGroup): + return arg.group + elif arg == dist.isend: + arg = old_isend + elif arg == dist.irecv: + arg = old_irecv + return arg + + args = (convert(arg) for arg in args) + kwargs = {k: convert(v) for k, v in kwargs.items()} + return func(*args, **kwargs) + + return new_function + + dist.P2POp.__new__ = get_new_p2pop_function(dist.P2POp.__new__) + dist.P2POp.__init__ = get_new_p2pop_function(dist.P2POp.__init__) + + +class ReloadableProcessGroup(torch.distributed.ProcessGroup): + GROUPS = {} + + def __init__(self, group, ranks): + super().__init__( + rank=dist.get_rank(group), + size=dist.get_world_size(group), + ) + 
self.group = group + self.group_info = { + "ranks": ranks, + } + pid = os.getpid() + if pid not in ReloadableProcessGroup.GROUPS: + ReloadableProcessGroup.GROUPS[pid] = [] + ReloadableProcessGroup.GROUPS[pid].append(self) + + def __getattr__(self, name): + return getattr(self.group, name) + + @staticmethod + def destroy_process_groups(): + pid = os.getpid() + if pid in ReloadableProcessGroup.GROUPS: + logger.info(f"Destroying {len(ReloadableProcessGroup.GROUPS[pid])} process groups in pid {pid}") + for reloadable_group in ReloadableProcessGroup.GROUPS[pid]: + if reloadable_group.group is None: + continue + dist.destroy_process_group(reloadable_group.group) + + del reloadable_group.group + reloadable_group.group = None + + @staticmethod + def reload_process_groups(): + pid = os.getpid() + if pid in ReloadableProcessGroup.GROUPS: + logger.info(f"Reloading {len(ReloadableProcessGroup.GROUPS[pid])} process groups in pid {pid}") + old_new_group = old_new_group_dict[pid] + for reloadable_group in ReloadableProcessGroup.GROUPS[pid]: + if reloadable_group.group is not None: + continue + group = old_new_group(ranks=reloadable_group.group_info["ranks"], backend="nccl") + reloadable_group.group = group + + def rank(self) -> int: + return self.group.rank() + + def size(self) -> int: + return self.group.size() + + def name(self) -> str: + return self.group.name() + + def shutdown(self) -> None: + if self.group is not None: + self.group.shutdown() + + def abort(self) -> None: + if self.group is not None: + self.group.abort() + + def _fwd(self, method, *args, **kwargs): + inner = self.group + if inner is None: + raise RuntimeError("ReloadableProcessGroup: inner PG is None, call reload() first.") + return getattr(inner, method)(*args, **kwargs) + + def barrier(self, *a, **kw): + return self._fwd("barrier", *a, **kw) + + def broadcast(self, *a, **kw): + return self._fwd("broadcast", *a, **kw) + + def allreduce(self, *a, **kw): + return self._fwd("allreduce", *a, **kw) + + def allreduce_coalesced(self, *a, **kw): + return self._fwd("allreduce_coalesced", *a, **kw) + + def reduce(self, *a, **kw): + return self._fwd("reduce", *a, **kw) + + def allgather(self, *a, **kw): + return self._fwd("allgather", *a, **kw) + + def _allgather_base(self, *a, **kw): + return self._fwd("_allgather_base", *a, **kw) + + def allgather_coalesced(self, *a, **kw): + return self._fwd("allgather_coalesced", *a, **kw) + + def allgather_into_tensor_coalesced(self, *a, **kw): + return self._fwd("allgather_into_tensor_coalesced", *a, **kw) + + def gather(self, *a, **kw): + return self._fwd("gather", *a, **kw) + + def scatter(self, *a, **kw): + return self._fwd("scatter", *a, **kw) + + def reduce_scatter(self, *a, **kw): + return self._fwd("reduce_scatter", *a, **kw) + + def _reduce_scatter_base(self, *a, **kw): + return self._fwd("_reduce_scatter_base", *a, **kw) + + def reduce_scatter_tensor_coalesced(self, *a, **kw): + return self._fwd("reduce_scatter_tensor_coalesced", *a, **kw) + + def alltoall_base(self, *a, **kw): + return self._fwd("alltoall_base", *a, **kw) + + def alltoall(self, *a, **kw): + return self._fwd("alltoall", *a, **kw) + + def send(self, *a, **kw): + return self._fwd("send", *a, **kw) + + def recv(self, *a, **kw): + return self._fwd("recv", *a, **kw) + + def recv_anysource(self, *a, **kw): + return self._fwd("recv_anysource", *a, **kw) + + def _start_coalescing(self, *a, **kw): + return self._fwd("_start_coalescing", *a, **kw) + + def _end_coalescing(self, *a, **kw): + return self._fwd("_end_coalescing", *a, **kw) + 
+ def _get_backend_name(self): + return self._fwd("_get_backend_name") + + def _get_backend(self, *a, **kw): + return self._fwd("_get_backend", *a, **kw) + + def _set_default_backend(self, *a, **kw): + return self._fwd("_set_default_backend", *a, **kw) + + @property + def bound_device_id(self): + return self.group.bound_device_id + + @bound_device_id.setter + def bound_device_id(self, dev): + self.group.bound_device_id = dev + + +def destroy_process_groups(): + """Destroy all reloadable process groups.""" + ReloadableProcessGroup.destroy_process_groups() + + +def reload_process_groups(): + """Reload all reloadable process groups.""" + ReloadableProcessGroup.reload_process_groups() From 796603c1cebb885e069ff6fcafc95d2b3d4e8430 Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Fri, 5 Dec 2025 17:14:53 +0800 Subject: [PATCH 27/58] (feat): support pytorch280 docker. --- requirements_torch280_sglang.txt | 5 ++++- requirements_torch280_vllm.txt | 11 ++++++++++- roll/platforms/cuda.py | 1 + 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/requirements_torch280_sglang.txt b/requirements_torch280_sglang.txt index 3817cce34..f42a343f4 100644 --- a/requirements_torch280_sglang.txt +++ b/requirements_torch280_sglang.txt @@ -1,5 +1,8 @@ -r requirements_common.txt -deepspeed==0.16.4 +torch==2.8.0.* +torchvision==0.23.0.* +torchaudio==2.8.0.* +deepspeed==0.16.4 sglang[srt,torch-memory-saver]==0.5.2 \ No newline at end of file diff --git a/requirements_torch280_vllm.txt b/requirements_torch280_vllm.txt index f9ac68324..0eec024f2 100644 --- a/requirements_torch280_vllm.txt +++ b/requirements_torch280_vllm.txt @@ -1,4 +1,13 @@ -r requirements_common.txt +torch==2.8.0.* +torchvision==0.23.0.* +torchaudio==2.8.0.* + transformers==4.57.0 -vllm==0.11.0 +deepspeed==0.16.4 + +flash-attn + +#todo upgrade docker image +vllm==0.10.2 diff --git a/roll/platforms/cuda.py b/roll/platforms/cuda.py index fcf2c289a..5b46ba925 100644 --- a/roll/platforms/cuda.py +++ b/roll/platforms/cuda.py @@ -65,6 +65,7 @@ def get_vllm_run_time_env_vars(cls, gpu_rank: str) -> dict: env_vars = { "PYTORCH_CUDA_ALLOC_CONF": "", "VLLM_ALLOW_INSECURE_SERIALIZATION":"1", + "VLLM_ALLREDUCE_USE_SYMM_MEM": "0", # vllm 0.11.0 bug: https://github.com/vllm-project/vllm/issues/24694 "CUDA_VISIBLE_DEVICES": f"{gpu_rank}", "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", } From 6f5b8f75ce5fc15dd643499a5e932c362fce59cf Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Mon, 10 Nov 2025 13:51:56 +0800 Subject: [PATCH 28/58] (fix): fix vllm 0110 model_config. 
--- roll/third_party/vllm/vllm_0_11_0/llm.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/roll/third_party/vllm/vllm_0_11_0/llm.py b/roll/third_party/vllm/vllm_0_11_0/llm.py index dc75e1340..8734db1a8 100644 --- a/roll/third_party/vllm/vllm_0_11_0/llm.py +++ b/roll/third_party/vllm/vllm_0_11_0/llm.py @@ -197,10 +197,9 @@ def __init__( self.supported_tasks = supported_tasks # Load the Input/Output processor plugin if any - io_processor_plugin = self.llm_engine.model_config.io_processor_plugin - self.io_processor = get_io_processor(self.llm_engine.vllm_config, - io_processor_plugin) - + self.model_config = self.llm_engine.model_config + self.processor = self.llm_engine.processor + self.io_processor = self.llm_engine.io_processor def load_states(self): self.collective_rpc(method="load_states") From 29301e712004fa80615181ba831daa12827b6ccc Mon Sep 17 00:00:00 2001 From: "weixun.wwx" Date: Tue, 11 Nov 2025 11:29:48 +0800 Subject: [PATCH 29/58] (refactor): refactor agentic norm. --- roll/pipeline/agentic/agentic_config.py | 123 +++++++++++-- roll/pipeline/agentic/agentic_pipeline.py | 3 +- roll/pipeline/agentic/utils.py | 207 ++++++++++------------ tests/utils/test_score_norm_fn.py | 35 ---- 4 files changed, 204 insertions(+), 164 deletions(-) delete mode 100644 tests/utils/test_score_norm_fn.py diff --git a/roll/pipeline/agentic/agentic_config.py b/roll/pipeline/agentic/agentic_config.py index 77f2db5c7..0c2d1c3b8 100644 --- a/roll/pipeline/agentic/agentic_config.py +++ b/roll/pipeline/agentic/agentic_config.py @@ -16,10 +16,88 @@ logger = get_logger() +def _resolve_reward_norm_defaults(method: str, grouping: str) -> Dict[str, Optional[str]]: + normalized_group = (grouping or "").lower() + if normalized_group == "batch": + mean_type = "batch" + std_type = "batch" + else: + if normalized_group not in ["state", "inductive"]: + logger.warning( + f"`RewardNormalizationConfig.grouping` 的取值 {normalized_group} 不在 ['batch', 'state', 'inductive'] 中,mean和std的统计范围设置为 'group', 然后再结合method来做选择norm的方式" + ) + mean_type = "group" + std_type = "group" + + if method == "identity": + return {"norm_mean_type": None, "norm_std_type": None} + elif method == "mean": + return {"norm_mean_type": mean_type, "norm_std_type": None} + elif method in {"mean_std", "asym_clip"}: + return {"norm_mean_type": mean_type, "norm_std_type": std_type} + + return {"norm_mean_type": None, "norm_std_type": None} + + @dataclass class RewardNormalizationConfig: grouping: str = field(default="state", metadata={"help": "state / batch / inductive"}) - method: str = field(default="identity", metadata={"help": "asym_clip / identity / mean_std"}) + method: str = field( + default="identity", + metadata={ + "help": "已废弃字段: 取值仅用于推导 norm_mean_type / norm_std_type;请优先直接配置新字段", + "deprecated": True, + }, + ) + norm_mean_type: Optional[Literal["batch", "group"]] = field( + default=None, + metadata={ + "help": "Mean type for reward normalization: 'batch' (normalize across batch), 'group' (normalize within groups), None (without subtracting mean)" + }, + ) + norm_std_type: Optional[Literal["batch", "group"]] = field( + default=None, + metadata={ + "help": "Std type for reward normalization: 'batch' (normalize across batch), 'group' (normalize within groups), None (without dividing by std)" + }, + ) + + def __post_init__(self): + + if self.method not in {"identity", "mean", "mean_std", "asym_clip"}: + logger.warning( + f"`RewardNormalizationConfig.method` 的取值 {self.method!r} 已废弃且无效,将回退为 'identity'。" + ) + self.method = 
"identity" + + logger.warning( + "`RewardNormalizationConfig.method` 已废弃,后续版本将移除;显式配置 `norm_mean_type` / `norm_std_type` 优先级最高," + " `method` 仅在`norm_mean_type` / `norm_std_type`字段为空时参与兜底。" + ) + + defaults = _resolve_reward_norm_defaults(self.method, self.grouping) + if self.norm_mean_type is None: + logger.info( + "`norm_mean_type` 未显式配置,将依据 method=%s 与 grouping=%s 推导为 %s。", + self.method, + self.grouping, + defaults["norm_mean_type"], + ) + self.norm_mean_type = defaults["norm_mean_type"] + if self.norm_std_type is None: + logger.info( + "`norm_std_type` 未显式配置,将依据 method=%s 与 grouping=%s 推导为 %s。", + self.method, + self.grouping, + defaults["norm_std_type"], + ) + self.norm_std_type = defaults["norm_std_type"] + logger.info( + "`RewardNormalizationConfig` 将采用 norm_mean_type=%s, norm_std_type=%s。", + self.norm_mean_type, + self.norm_std_type, + ) + @dataclass class LLMProxyConfig: @@ -51,8 +129,7 @@ class EnvManagerConfig(WorkerConfig): metadata={"help": "The class of the worker."}, ) max_env_num_per_worker: int = field( - default=0, - metadata={"help": "The maximum number of envs per worker. one env per thread."} + default=0, metadata={"help": "The maximum number of envs per worker. one env per thread."} ) group_filter_cls: str = field( default="roll.pipeline.agentic.agentic_pipeline.GroupFilter", @@ -68,7 +145,9 @@ def __post_init__(self): logger.warning("all env in one worker by default, you can set max_env_num_per_worker to scale env.") logger.info(f"max_env_num_per_worker: {self.max_env_num_per_worker}") - self.world_size = (self.num_env_groups * self.final_group_size + self.max_env_num_per_worker - 1) // self.max_env_num_per_worker + self.world_size = ( + self.num_env_groups * self.final_group_size + self.max_env_num_per_worker - 1 + ) // self.max_env_num_per_worker self.env_configs: Optional[Dict[int, Dict[int, Dict]]] = None """ worker_rank: @@ -80,6 +159,7 @@ def __post_init__(self): def final_group_size(self): return self.group_size + self.group_size_redundancy + @dataclass class AgenticConfig(PPOConfig): # agentic related @@ -147,7 +227,9 @@ def __post_init__(self): if self.val_batch_size < 0: self.val_env_manager.max_traj_per_env = sys.maxsize else: - assert self.val_batch_size % val_env_num == 0, f"val_batch_size {self.val_batch_size} must be divisible by val_env_num {val_env_num}, equal best" + assert ( + self.val_batch_size % val_env_num == 0 + ), f"val_batch_size {self.val_batch_size} must be divisible by val_env_num {val_env_num}, equal best" traj_per_env = (self.val_batch_size + val_env_num - 1) // val_env_num if self.val_env_manager.max_traj_per_env < 0: @@ -164,7 +246,8 @@ def make_env_configs(self, env_manager_config: EnvManagerConfig): max_env_num_per_worker = env_manager_config.max_env_num_per_worker for tag, n_group in zip(env_manager_config.tags, env_manager_config.num_groups_partition): for env_id in range( - done_groups * env_manager_config.final_group_size, (done_groups + n_group) * env_manager_config.final_group_size + done_groups * env_manager_config.final_group_size, + (done_groups + n_group) * env_manager_config.final_group_size, ): cfg_template = self.custom_envs[tag] env_class = cfg_template.env_type @@ -178,22 +261,28 @@ def make_env_configs(self, env_manager_config: EnvManagerConfig): env_config = {**cfg_template.env_config} if group_id not in group_seeds: - group_seeds[group_id] = random.randint(0, 2**31-1) + group_seeds[group_id] = random.randint(0, 2**31 - 1) entry = {} entry.update(cfg_template) entry.pop("env_config", None) - entry.update({ - 
"tag": tag, - "group_id": group_id, - "env_id": env_id, - "config": env_config, - "env_class": env_class, - "env_manager_cls": cfg_template.get("env_manager_cls", "roll.pipeline.agentic.env_manager.traj_env_manager.TrajEnvManager"), - "group_seed": group_seeds[group_id], - }) + entry.update( + { + "tag": tag, + "group_id": group_id, + "env_id": env_id, + "config": env_config, + "env_class": env_class, + "env_manager_cls": cfg_template.get( + "env_manager_cls", "roll.pipeline.agentic.env_manager.traj_env_manager.TrajEnvManager" + ), + "group_seed": group_seeds[group_id], + } + ) worker_rank = env_id // max_env_num_per_worker env_configs[worker_rank][env_id] = DictConfig(entry) - logger.info(f"[ENV CONFIG] tag: {tag}, group_id: {group_id}, group_seeds: {group_seeds[group_id]}, env_id: {env_id}") + logger.info( + f"[ENV CONFIG] tag: {tag}, group_id: {group_id}, group_seeds: {group_seeds[group_id]}, env_id: {env_id}" + ) done_groups += n_group assert done_groups == env_manager_config.num_env_groups env_manager_config.env_configs = env_configs diff --git a/roll/pipeline/agentic/agentic_pipeline.py b/roll/pipeline/agentic/agentic_pipeline.py index 7ea27dee7..726affd11 100644 --- a/roll/pipeline/agentic/agentic_pipeline.py +++ b/roll/pipeline/agentic/agentic_pipeline.py @@ -219,8 +219,9 @@ def run(self): # Rewards need to be processed after grouping # We can group by tag(env_type)/traj_group_id(group)/batch(rollout_batch)... to compute rewards / advantages # The compute_response_level_rewards function injects a response_level_rewards key into batch.batch. - batch = compute_response_level_rewards(batch=batch, pipeline_config=self.pipeline_config) + batch, reward_metrics = compute_response_level_rewards(batch=batch, pipeline_config=self.pipeline_config) metrics.update(reduce_metrics(batch.meta_info.pop("metrics", {}))) + metrics.update(reward_metrics) metrics["time/cal_norm_rewards"] = timer.last with Timer(name="cal_token_reward", logger=None) as timer: diff --git a/roll/pipeline/agentic/utils.py b/roll/pipeline/agentic/utils.py index 13a383d32..6cd46ba05 100644 --- a/roll/pipeline/agentic/utils.py +++ b/roll/pipeline/agentic/utils.py @@ -18,7 +18,12 @@ from roll.pipeline.agentic.agentic_config import AgenticConfig, RewardNormalizationConfig from roll.pipeline.rlvr.utils import DUMPING_FUNC from roll.utils.logging import get_logger -from roll.utils.functionals import masked_whiten, compute_gae_advantage_return, compute_clip_fraction, compute_reinforce_return +from roll.utils.functionals import ( + masked_whiten, + compute_gae_advantage_return, + compute_clip_fraction, + compute_reinforce_return, +) logger = get_logger() @@ -50,29 +55,6 @@ def dump_rollout_render(save_dir, step, frames: List[List], env_ids: List, tags: logger.error(f"dump rollout render failed: {e}") logger.info(f"dump_rollout_render_cost: {timer.last}") -@torch.no_grad() -def get_score_normalize_fn(rn_cfg) -> Callable: - grouping, method = rn_cfg.grouping, rn_cfg.method - if method == "mean_std": - norm_func = lambda x: ( - (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + 1e-6) - if x.numel() > 1 and x.std(dim=-1, keepdim=True).abs().max() > 1e-6 - else torch.zeros_like(x) - ) # stable to bf16 than x.std() - elif method == "mean": - norm_func = lambda x: (x - x.mean(dim=-1, keepdim=True)) - elif method == "asym_clip": - norm_func = lambda x: ( - (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + 1e-6) - if x.numel() > 1 and x.std(dim=-1, keepdim=True).abs().max() > 1e-6 - else 
torch.zeros_like(x) - ).clamp(min=-1, max=3) - elif method == "identity": - norm_func = lambda x: x - else: - raise ValueError(f"Invalid normalization method: {method}") - - return norm_func @torch.no_grad() def compute_discounted_returns(batch: DataProto, adv_estimator, gamma=1.0) -> DataProto: @@ -88,10 +70,10 @@ def compute_discounted_returns(batch: DataProto, adv_estimator, gamma=1.0) -> Da DataProto: Updated batch where each trajectory contains an extra tensor key `"step_rewards"` holding the computed discounted returns. """ - if adv_estimator in ["gigpo", "step_reinforce" ]: + if adv_estimator in ["gigpo", "step_reinforce"]: batch.batch["sample_order_placeholder"] = torch.arange(batch.batch.batch_size[0], device=batch.batch.device) batch_group_by_traj: Dict[str, DataProto] = batch.group_by(keys="traj_id") - for traj_id, traj_batch in batch_group_by_traj.items(): + for traj_id, traj_batch in batch_group_by_traj.items(): indices: Tensor = torch.argsort(torch.from_numpy(traj_batch.non_tensor_batch["step"].astype(np.int64))) traj_batch.reorder(indices) @@ -111,23 +93,65 @@ def compute_discounted_returns(batch: DataProto, adv_estimator, gamma=1.0) -> Da else: return batch -def grouped_reward_norm(batch: "DataProto", reward_normalization: RewardNormalizationConfig) -> torch.Tensor: + +# TODO: 这里的功能性和rlvr比较接近,但因为后续agentic会有潜在的修改需求,所以就先拎出来 +@torch.no_grad() +def agentic_reward_norm(batch: "DataProto", reward_normalization: RewardNormalizationConfig) -> torch.Tensor: batch.batch["sample_order_placeholder"] = torch.arange(batch.batch.batch_size[0], device=batch.batch.device) grouping = reward_normalization.grouping + norm_mean_type = reward_normalization.norm_mean_type + norm_std_type = reward_normalization.norm_std_type + + all_scores = batch.batch["scores"].float() + batch_mean = None + batch_std = None + if norm_mean_type == "batch": + batch_mean = all_scores.mean() + if norm_std_type == "batch": + batch_std = all_scores.std() + + batch_list = [] batch_grouped: Dict[str, DataProto] = {"default": batch} if grouping != "batch": batch_grouped = batch.group_by(keys=grouping) - batch_list = [] for group_name, group_batch in batch_grouped.items(): - score_norm_fn = get_score_normalize_fn(rn_cfg=reward_normalization) - normalized_acc_scores = score_norm_fn(group_batch.batch["scores"]) - group_batch.batch["grouped_rewards"] = normalized_acc_scores + scores = group_batch.batch["scores"] + original_dtype = scores.dtype + scores_float = scores.float() + + if norm_mean_type == "batch": + reward_mean = batch_mean + elif norm_mean_type == "group": + reward_mean = scores_float.mean() + else: + reward_mean = 0.0 + + if norm_std_type == "batch": + reward_std = batch_std + elif norm_std_type == "group": + reward_std = scores_float.std() + else: + reward_std = None + + if reward_std is not None: + # 处理单个元素或标准差为0的情况,避免除以0 + if scores_float.numel() > 1 and reward_std.abs() > 1e-6: + normalized_scores = (scores_float - reward_mean) / (reward_std + 1e-6) + else: + normalized_scores = torch.zeros_like(scores_float) + else: + normalized_scores = scores_float - reward_mean + + normalized_scores = normalized_scores.to(dtype=original_dtype) + group_batch.batch["grouped_rewards"] = normalized_scores batch_list.append(group_batch) + batch = DataProto.concat(batch_list) batch.reorder(indices=torch.argsort(batch.batch["sample_order_placeholder"])) batch.pop("sample_order_placeholder") return batch.batch.pop("grouped_rewards") + def build_state_group(batch: "DataProto") -> "DataProto": 
batch.batch["sample_order_placeholder"] = torch.arange(batch.batch.batch_size[0], device=batch.batch.device) batch_group_by_traj_group: Dict[str, DataProto] = batch.group_by(keys="traj_group_id") @@ -135,7 +159,9 @@ def build_state_group(batch: "DataProto") -> "DataProto": for traj_group_id, traj_group_batch in batch_group_by_traj_group.items(): batch_group_by_state: Dict[str, DataProto] = traj_group_batch.group_by(keys="state_hash") for state, state_batch in batch_group_by_state.items(): - state_batch.non_tensor_batch["state_group_id"] = np.array([state] * state_batch.batch.batch_size[0], dtype=object) + state_batch.non_tensor_batch["state_group_id"] = np.array( + [state] * state_batch.batch.batch_size[0], dtype=object + ) merged.append(state_batch) state_batch_size = [len(m) for m in merged] merged = DataProto.concat(merged) @@ -148,15 +174,19 @@ def build_state_group(batch: "DataProto") -> "DataProto": merged.meta_info["metrics"] = metrics return merged + @torch.no_grad() def compute_response_level_rewards(batch: "DataProto", pipeline_config: AgenticConfig) -> "DataProto": + reward_metrics = {} if pipeline_config.adv_estimator == "gigpo": # ref: https://github.com/langfengQ/verl-agent/blob/e03bd502667c45172e8c093cc506db8438ae8ab5/gigpo/core_gigpo.py#L109 # step 1 episode_scores = torch.from_numpy(batch.non_tensor_batch["episode_scores"].astype(np.float32)) scores_to_group = DataProto.from_dict({"scores": episode_scores}) scores_to_group.non_tensor_batch = batch.non_tensor_batch - episode_rewards: torch.Tensor = grouped_reward_norm(scores_to_group, reward_normalization=pipeline_config.reward_normalization) + episode_rewards: torch.Tensor = agentic_reward_norm( + scores_to_group, reward_normalization=pipeline_config.reward_normalization + ) # step 2 batch = build_state_group(batch=batch) @@ -164,23 +194,44 @@ def compute_response_level_rewards(batch: "DataProto", pipeline_config: AgenticC # step 3 scores_to_group = DataProto.from_dict({"scores": batch.batch["step_rewards"]}) scores_to_group.non_tensor_batch = batch.non_tensor_batch - step_rewards: torch.Tensor = grouped_reward_norm(batch=scores_to_group, - reward_normalization=RewardNormalizationConfig(grouping="state_group_id", - method=pipeline_config.reward_normalization.method)) + step_rewards: torch.Tensor = agentic_reward_norm( + batch=scores_to_group, + reward_normalization=RewardNormalizationConfig( + grouping="state_group_id", method=pipeline_config.reward_normalization.method + ), + ) - batch.batch["response_level_rewards"] = pipeline_config.episode_reward_weight * episode_rewards + pipeline_config.step_reward_weight * step_rewards + batch.batch["response_level_rewards"] = ( + pipeline_config.episode_reward_weight * episode_rewards + pipeline_config.step_reward_weight * step_rewards + ) batch.batch["episode_rewards_norm"] = episode_rewards batch.batch["step_rewards_norm"] = step_rewards elif pipeline_config.adv_estimator == "step_reinforce": scores_to_group = DataProto.from_dict({"scores": batch.batch["step_rewards"]}) scores_to_group.non_tensor_batch = batch.non_tensor_batch - batch.batch["response_level_rewards"] = grouped_reward_norm(scores_to_group, reward_normalization=pipeline_config.reward_normalization) + batch.batch["response_level_rewards"] = agentic_reward_norm( + scores_to_group, reward_normalization=pipeline_config.reward_normalization + ) else: scores_to_group = DataProto.from_dict({"scores": batch.batch["scores"].clone().sum(dim=-1)}) scores_to_group.non_tensor_batch = batch.non_tensor_batch - 
batch.batch["response_level_rewards"] = grouped_reward_norm(scores_to_group, reward_normalization=pipeline_config.reward_normalization) + batch.batch["response_level_rewards"] = agentic_reward_norm( + scores_to_group, reward_normalization=pipeline_config.reward_normalization + ) + + # 加上clip + if pipeline_config.reward_clip: + reward_metrics["critic/reward_clip_frac"] = compute_clip_fraction( + values=batch.batch["response_level_rewards"], + clip_min=-pipeline_config.reward_clip, + clip_max=pipeline_config.reward_clip, + ) + batch.batch["response_level_rewards"] = torch.clamp( + batch.batch["response_level_rewards"], min=-pipeline_config.reward_clip, max=pipeline_config.reward_clip + ) + + return batch, reward_metrics - return batch @torch.no_grad() def get_agentic_response_level_mask(data: "DataProto", pipeline_config: AgenticConfig): @@ -228,6 +279,7 @@ def dump_frames_as_gif(filename, frames, duration=0.2): print_only_once = True pass + def dump_rollout_trajectories(path, global_step, data: DataProto): """ Dumps rollout trajectories to persistent storage. @@ -256,77 +308,14 @@ def dump_rollout_trajectories(path, global_step, data: DataProto): [data.non_tensor_batch.pop(item[0]) for item in columns_config if item[0] in data.non_tensor_batch] data_cnt = len(data) - write_data['global_step'] = [global_step] * data_cnt - columns_config.append(['global_step','bigint']) + write_data["global_step"] = [global_step] * data_cnt + columns_config.append(["global_step", "bigint"]) for checker, func in DUMPING_FUNC: if checker(path): p = multiprocessing.Process(target=func, args=(path, write_data, columns_config), daemon=False) p.start() -def compute_agentic_reinforce_return(token_level_rewards: torch.Tensor, gamma: torch.Tensor, lambd: torch.Tensor, mask: Optional[torch.Tensor] = None): - """ - 计算 REINFORCE 的 return,支持按 mask 分段 discount 衰减。 - 每段内所有位置获得相同的折扣累积值(从该段最后位置开始累积)。 - - Args: - token_level_rewards: [batch_size, seq_len] token 级别的奖励 - gamma: discount factor - lambd: lambda 参数(当前未使用,保留以兼容接口) - mask: [batch_size, seq_len] mask,1表示有效位置,0表示无效位置。如果为None,则对所有位置计算 - - Returns: - advantages: [batch_size, seq_len] advantages - returns: [batch_size, seq_len] returns - """ - with torch.no_grad(): - batch_size, gen_len = token_level_rewards.shape - device = token_level_rewards.device - returns = torch.zeros_like(token_level_rewards, dtype=torch.float32) - - # 如果没有提供 mask,则对所有位置计算(向后兼容) - if mask is None: - mask = torch.ones_like(token_level_rewards) - - # 确保 gamma 是标量 - gamma_val = gamma.item() if torch.is_tensor(gamma) else gamma - - # 对每个样本分别处理 - for b in range(batch_size): - sample_mask = mask[b] # [seq_len] - sample_rewards = token_level_rewards[b] # [seq_len] - - # 找到所有连续的1的段 - # 使用 diff 找到边界:1->0 和 0->1 的位置 - diff = torch.diff(sample_mask.float(), prepend=torch.tensor([0.0], device=device)) - - # 找到段的开始位置(0->1,diff==1) - segment_starts = torch.where(diff == 1)[0] - - # 找到段的结束位置(1->0,diff==-1) - segment_ends = torch.where(diff == -1)[0] - - # 如果最后一个位置是1,需要添加结束位置 - if len(sample_mask) > 0 and sample_mask[-1] == 1: - segment_ends = torch.cat([segment_ends, torch.tensor([gen_len], device=device)]) - - # 计算该段从最后位置开始的累积折扣奖励 - cumulative_return = 0.0 - # 对每段分别计算 discounted return - for start, end in zip(segment_starts.flip(-1), segment_ends.flip(-1)): - start_idx = start.item() - end_idx = end.item() - segment_len = end_idx - start_idx - - cumulative_return = sample_rewards[end_idx-1].item() + gamma_val * cumulative_return - - # 该段内所有位置都设置为这个累积值 - returns[b, start_idx:end_idx] = cumulative_return - - 
advantages = returns - - return advantages, returns - @torch.no_grad() def agentic_compute_advantage( data: "DataProto", @@ -361,13 +350,9 @@ def agentic_compute_advantage( token_level_rewards=token_level_rewards, gamma=gamma, lambd=lambd ) elif adv_estimator in ["agentic_reinforce"]: - advantages, returns = compute_agentic_reinforce_return( - token_level_rewards=token_level_rewards, gamma=gamma, lambd=lambd, mask=response_mask - ) + raise NotImplementedError else: raise NotImplementedError - - data.batch["raw_advantages"] = advantages if whiten_advantages: # TODO whiten过程中是否要考虑response的长度? advantages = masked_whiten(values=advantages, mask=response_mask) @@ -380,4 +365,4 @@ def agentic_compute_advantage( data.batch["advantages"] = advantages data.batch["returns"] = returns - return data \ No newline at end of file + return data diff --git a/tests/utils/test_score_norm_fn.py b/tests/utils/test_score_norm_fn.py deleted file mode 100644 index 254d97899..000000000 --- a/tests/utils/test_score_norm_fn.py +++ /dev/null @@ -1,35 +0,0 @@ -# tests/test_get_score_normalize_fn.py -import pytest -import torch -from dataclasses import dataclass - -from roll.pipeline.agentic.utils import get_score_normalize_fn - - -@dataclass -class MockRNCfg: - grouping: str = "dummy" - method: str = "mean_std" - - -@pytest.mark.parametrize( - "method,input_tensor,expected", - [ - ("mean_std", torch.tensor([1.0, 2.0, 3.0]), torch.tensor([-1.0, 0.0, 1.0])), - ("mean", torch.tensor([1.0, 2.0, 3.0]), torch.tensor([-1.0, 0.0, 1.0])), - ("asym_clip", torch.tensor([1.0, 2.0, 3.0]), torch.tensor([-1.0, 0.0, 1.0]).clamp(-1, 3)), - ("identity", torch.tensor([1.0, 2.0, 3.0]), torch.tensor([1.0, 2.0, 3.0])), - ], -) -def test_get_score_normalize_fn(method, input_tensor, expected): - cfg = MockRNCfg(method=method) - fn = get_score_normalize_fn(cfg) - out = fn(input_tensor) - torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-4) - - -def test_single_element_fallback(): - cfg = MockRNCfg(method="mean_std") - fn = get_score_normalize_fn(cfg) - out = fn(torch.tensor([5.0])) - torch.testing.assert_close(out, torch.tensor([0.0])) \ No newline at end of file From 5d92cc0fe2d3a423623fac01f955743eed235102 Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Tue, 2 Dec 2025 16:19:08 +0800 Subject: [PATCH 30/58] (feat): add agentic profile metrics. 
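
This patch threads step-level timing, `get_batch` latency, and vLLM engine gauges (KV-cache usage, waiting requests, preemptions) into the agentic pipeline metrics, and extends `reduce_metrics` so the aggregation is chosen from the metric-name suffix. Below is a minimal, self-contained sketch of that suffix convention, assuming only numpy; the helper name and the fallback-to-mean behaviour mirror the diff that follows, but treat it as illustrative rather than the exact implementation:

```python
import numpy as np

# Map a metric-name suffix to an aggregation; anything else falls back to mean.
_SUFFIX_AGGS = {
    "_max": np.max,
    "_min": np.min,
    "_sum": np.sum,
    "_std": np.std,
    "_p50": lambda x: np.percentile(x, 50),
    "_p99": lambda x: np.percentile(x, 99),
    "_mean": np.mean,
}

def reduce_by_suffix(metrics: dict) -> dict:
    """Reduce lists of per-snapshot values into one scalar per metric."""
    reduced = {}
    for name, values in metrics.items():
        agg = next((fn for suffix, fn in _SUFFIX_AGGS.items() if name.endswith(suffix)), np.mean)
        reduced[name] = float(agg(values))
    return reduced

# Example: per-second snapshots collected during rollout, reduced once per step.
snapshots = {
    "vllm/kv_cache_usage_perc_max": [0.31, 0.78, 0.64],  # suffix -> max
    "time/get_batch_cost_gqm": [1.2, 0.9, 1.1],          # no suffix -> mean
}
print(reduce_by_suffix(snapshots))
# {'vllm/kv_cache_usage_perc_max': 0.78, 'time/get_batch_cost_gqm': 1.0666...}
```

The suffix encodes the reduction in the metric name itself, so producers (e.g. the vLLM metrics snapshot thread) and the pipeline-side reducer stay decoupled.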
--- roll/distributed/executor/worker.py | 19 +- .../scheduler/rollout_scheduler.py | 12 +- roll/distributed/strategy/strategy.py | 14 +- roll/distributed/strategy/vllm_strategy.py | 52 ++- roll/pipeline/agentic/agentic_pipeline.py | 369 ++++++++++-------- .../agentic/agentic_rollout_pipeline.py | 8 +- roll/pipeline/agentic/environment_worker.py | 32 +- roll/utils/functionals.py | 49 ++- roll/utils/import_utils.py | 4 +- roll/utils/offload_states.py | 1 - 10 files changed, 374 insertions(+), 186 deletions(-) diff --git a/roll/distributed/executor/worker.py b/roll/distributed/executor/worker.py index 4f11f0cb5..ce162a2dd 100644 --- a/roll/distributed/executor/worker.py +++ b/roll/distributed/executor/worker.py @@ -3,7 +3,7 @@ import socket from concurrent import futures from dataclasses import dataclass -from typing import Dict +from typing import Dict, Optional, List import ray @@ -210,3 +210,20 @@ def add_lora(self, *args, **kwargs): def download_models(self, model_name_or_paths: set[str]): futures.wait([self.thread_executor.submit(download_model, model_name_or_path) for model_name_or_path in model_name_or_paths]) + + @register(dispatch_mode=Dispatch.DP_MP_COMPUTE) + def get_metrics(self, metric_names: Optional[List[str]] = None) -> DataProto: + """ + Get performance metrics from the strategy layer. + + Args: + metric_names: Optional list of specific metric names to filter + + Returns: + Dictionary of metric names to aggregated values + """ + if getattr(self, "strategy", None) is not None: + metrics = self.strategy.get_metrics(metric_names=metric_names) + else: + metrics = {} + return DataProto(meta_info={"metrics": metrics}) diff --git a/roll/distributed/scheduler/rollout_scheduler.py b/roll/distributed/scheduler/rollout_scheduler.py index 2dd959d65..8f301730c 100644 --- a/roll/distributed/scheduler/rollout_scheduler.py +++ b/roll/distributed/scheduler/rollout_scheduler.py @@ -1,5 +1,6 @@ import asyncio import random +import time from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple @@ -282,6 +283,9 @@ async def wait_a_episode(): self.pending_gets.update(pending) await wait_a_episode() + get_batch_return_start_time = time.time() + for d in ret: + d.meta_info["get_batch_return_start_time"] = get_batch_return_start_time return ret class RolloutScheduler: @@ -385,8 +389,14 @@ async def get_batch(self, data: DataProto, batch_size): return None metrics = {} - [append_to_dict(metrics, meta_info.meta_info["metrics"]) for meta_info in data_batch] + get_batch_return_start_time = None + for d_item in data_batch: + get_batch_return_start_time = d_item.meta_info.pop("get_batch_return_start_time", None) + append_to_dict(metrics, d_item.meta_info["metrics"]) + if get_batch_return_start_time is not None: + metrics["time/get_batch_cost_gqm"] = time.time() - get_batch_return_start_time metrics.update(await self.env_output_queue.collect_metrics.remote()) batch = DataProto.concat(data_batch) batch.meta_info["metrics"] = metrics + batch.meta_info["get_batch_return_start_time"] = time.time() return batch diff --git a/roll/distributed/strategy/strategy.py b/roll/distributed/strategy/strategy.py index 7c49c5c80..6d52d85e9 100644 --- a/roll/distributed/strategy/strategy.py +++ b/roll/distributed/strategy/strategy.py @@ -1,6 +1,6 @@ from abc import ABC from concurrent import futures -from typing import Callable, Dict, Tuple +from typing import Callable, Dict, List, Optional, Tuple import torch import torch.nn.functional as F @@ -50,6 +50,18 @@ def get_data_input(self, batch: 
"DataProto") -> "DataProto": def generate(self, *args, **kwargs): raise NotImplementedError + def get_metrics(self, metric_names: Optional[List[str]] = None) -> Dict[str, float]: + """ + Get performance metrics from the strategy. + + Args: + metric_names: Optional list of specific metric names to filter + + Returns: + Dictionary of metric names to aggregated values + """ + return {} + def start_server(self, *args, **kwargs): raise NotImplementedError diff --git a/roll/distributed/strategy/vllm_strategy.py b/roll/distributed/strategy/vllm_strategy.py index 00ecacb10..1ecb925e7 100644 --- a/roll/distributed/strategy/vllm_strategy.py +++ b/roll/distributed/strategy/vllm_strategy.py @@ -3,7 +3,9 @@ import gc import os import queue +import threading import time +from collections import defaultdict, deque from concurrent import futures from typing import Dict, List, Optional, Union @@ -19,11 +21,11 @@ from vllm.utils import random_uuid from roll.distributed.executor.worker import Worker -from roll.distributed.scheduler.protocol import DataProto +from roll.distributed.scheduler.protocol import DataProto, list_of_dict_to_dict_of_list from roll.distributed.strategy.strategy import InferenceStrategy from roll.third_party.vllm import LLM, AsyncLLM from roll.utils.collective import collective -from roll.utils.functionals import GenerateRequestType, concatenate_input_and_output +from roll.utils.functionals import GenerateRequestType, concatenate_input_and_output, reduce_metrics from roll.utils.logging import get_logger from roll.utils.offload_states import OffloadStateType from roll.platforms import current_platform @@ -51,6 +53,11 @@ def __init__(self, worker: Worker): self.request_metas = {} self.running = False + + # Metrics snapshot infrastructure + self._metrics_snapshots = deque(maxlen=3600) + self._metrics_snapshot_interval = 1.0 # Snapshot every 1 second + self._metrics_thread = None def initialize(self, model_provider): set_seed(seed=self.worker.pipeline_config.seed) @@ -125,6 +132,12 @@ def initialize(self, model_provider): self.is_model_in_gpu = True + self._metrics_thread = threading.Thread( + target=self._collect_metrics_snapshot, + name="metrics-collection" + ) + self._metrics_thread.start() + def op_compute_log_probs(self, logits: torch.Tensor, input_ids: torch.Tensor, attention_mask: torch.Tensor): """ vllm实现compute log probs在这里实现即可 @@ -443,6 +456,41 @@ def update_parameter_in_bucket(self, model_update_name, meta_infos, buffer, rank def add_lora(self, peft_config): self.model.add_lora(peft_config) + def _collect_metrics_snapshot(self): + """Collect metrics snapshots periodically in a background thread.""" + while True: + raw_metrics = self.model.get_metrics() + snapshot = { + 'vllm/kv_cache_usage_perc_max': [], + 'vllm/num_requests_waiting_max': [], + 'vllm/num_preemptions_max': [] + } + for metric in raw_metrics: + if metric.name == "vllm:kv_cache_usage_perc": + snapshot['vllm/kv_cache_usage_perc_max'].append(metric.value) + elif metric.name == "vllm:num_requests_waiting": + snapshot['vllm/num_requests_waiting_max'].append(metric.value) + elif metric.name == "vllm:num_preemptions": + snapshot['vllm/num_preemptions_max'].append(metric.value) + self._metrics_snapshots.append(snapshot) + + time.sleep(self._metrics_snapshot_interval) + + def get_metrics(self, metric_names: Optional[List[str]] = None) -> Dict[str, float]: + """ + Get aggregated metrics for the time interval since last call. 
+ + Args: + metric_names: Optional list of specific metric names to filter + + Returns: + Dictionary of metric names to aggregated values + """ + if not self._metrics_snapshots: + return {} + metrics_snapshots = list_of_dict_to_dict_of_list(self._metrics_snapshots) + self._metrics_snapshots.clear() + return reduce_metrics(metrics_snapshots) def gather_unpadded_input_ids(input_ids: torch.Tensor, attention_mask: torch.Tensor): gathered_input_ids = [ids[mask.bool()].tolist() for ids, mask in zip(input_ids, attention_mask)] diff --git a/roll/pipeline/agentic/agentic_pipeline.py b/roll/pipeline/agentic/agentic_pipeline.py index 726affd11..923658c07 100644 --- a/roll/pipeline/agentic/agentic_pipeline.py +++ b/roll/pipeline/agentic/agentic_pipeline.py @@ -1,6 +1,7 @@ import json import os.path import random +import time from typing import Any, Dict, List import numpy as np @@ -136,183 +137,206 @@ def run(self): continue logger.info(f"pipeline rollout global step {global_step} start...") metrics = {} - with tps_timer: - if self.pipeline_config.adv_estimator == "gae": - self.critic.offload_states(blocking=True) - self.actor_train.offload_states(blocking=True) - - ray.get(self.train_rollout_scheduler.suspend.remote()) - if self.pipeline_config.async_generation_ratio > 0: - self.actor_infer.stop_server() - model_update_metrics: Dict = self.model_update(global_step) - metrics.update(model_update_metrics) - if self.pipeline_config.async_generation_ratio > 0: - self.actor_infer.start_server(data=DataProto(meta_info={"global_step": global_step, "is_offload_states": False})) - else: - self.actor_infer.start_server(data=DataProto(meta_info={"global_step": global_step, "is_offload_states": True})) - - batch: DataProto = DataProto() - batch.meta_info = {"global_step": global_step} - - if global_step % self.pipeline_config.eval_steps == 0: - metrics.update(self.val(global_step=global_step)) - - with Timer(name="rollout", logger=None) as rollout_timer: - batch.meta_info["is_offload_states"] = True - batch = ray.get(self.train_rollout_scheduler.get_batch.remote(batch, self.pipeline_config.rollout_batch_size)) - dump_rollout_trajectories(self.pipeline_config.rollout_dump_dir, global_step, batch) - - metrics["time/rollout"] = rollout_timer.last - metrics.update(reduce_metrics(batch.meta_info.pop("metrics", {}))) - batch.meta_info["global_step"] = global_step - if not (self.pipeline_config.async_generation_ratio > 0): - self.actor_infer.stop_server() - - batch = compute_discounted_returns(batch, self.pipeline_config.adv_estimator, self.pipeline_config.step_reward_gamma) - - batch = self.adjust_batch(batch, mode=self.pipeline_config.batch_adjust_mode) - metrics.update(reduce_metrics(batch.meta_info.pop("metrics", {}))) - - with Timer(name="cal_ref_log_probs", logger=None) as cal_timer: - ref_log_probs_refs: List[ray.ObjectRef] = self.reference.compute_log_probs(batch, blocking=False) - ref_log_probs = DataProto.materialize_concat(data_refs=ref_log_probs_refs) - ref_log_probs.rename(old_keys="log_probs", new_keys="ref_log_probs") - batch = batch.union(ref_log_probs) - avg_ref_log_prob = masked_mean(batch.batch["ref_log_probs"], batch.batch["response_mask"][:, 1:]) - metrics.update(reduce_metrics(ref_log_probs.meta_info.pop("metrics", {}))) - metrics.update({"critic/ref_log_prob/mean": avg_ref_log_prob.item()}) - metrics["time/ref_log_probs_values_reward"] = cal_timer.last - - with Timer(name="cal_old_log_probs_values", logger=None) as cal_old_logpb_timer: - # TODO: use engine log_probs as old_log_probs - 
batch.meta_info["is_offload_states"] = False - old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs(batch, blocking=False) - if self.pipeline_config.adv_estimator == "gae": - values_refs: List[ray.ObjectRef] = self.critic.compute_values(batch, blocking=False) - old_log_probs = DataProto.materialize_concat(data_refs=old_log_probs_refs) + + # Add overall step timing + with Timer(name="pipeline_step_total", logger=None) as step_timer: + with tps_timer: if self.pipeline_config.adv_estimator == "gae": - values = DataProto.materialize_concat(data_refs=values_refs) - batch = batch.union(values) - metrics.update(reduce_metrics(values.meta_info.pop("metrics", {}))) - batch.batch["old_log_probs"] = old_log_probs.batch["log_probs"] - avg_old_log_prob = masked_mean(batch.batch["old_log_probs"], batch.batch["response_mask"][:, 1:]) - metrics.update({"critic/old_log_prob/mean": avg_old_log_prob.item()}) - - agg_entropy = agg_loss( - loss_mat=old_log_probs.batch["entropy"], - loss_mask=batch.batch["response_mask"][:, 1:], - loss_agg_mode="token-mean", - ) - metrics.update({"critic/entropy/mean": agg_entropy.item()}) - - metrics.update(reduce_metrics(old_log_probs.meta_info.pop("metrics", {}))) - metrics["time/old_log_probs_values"] = cal_old_logpb_timer.last - - # TODO 当前这个还没用处 - with Timer(name="cal_response_level_mask", logger=None) as timer: - # TODO 补充完善的过滤要求,不同环境需要维持统一过滤标识 - batch, mask_metrics = get_agentic_response_level_mask(batch, self.pipeline_config) - metrics.update(mask_metrics) - metrics["time/cal_response_level_mask"] = timer.last - - with Timer(name="cal_response_norm_rewards", logger=None) as timer: - # Rewards need to be processed after grouping - # We can group by tag(env_type)/traj_group_id(group)/batch(rollout_batch)... to compute rewards / advantages - # The compute_response_level_rewards function injects a response_level_rewards key into batch.batch. 
- batch, reward_metrics = compute_response_level_rewards(batch=batch, pipeline_config=self.pipeline_config) + self.critic.offload_states(blocking=True) + self.actor_train.offload_states(blocking=True) + + ray.get(self.train_rollout_scheduler.suspend.remote()) + if self.pipeline_config.async_generation_ratio > 0: + self.actor_infer.stop_server() + + with Timer(name="model_update", logger=None) as model_update_timer: + model_update_metrics: Dict = self.model_update(global_step) + metrics["time/step_model_update"] =model_update_timer.last + + metrics.update(model_update_metrics) + if self.pipeline_config.async_generation_ratio > 0: + self.actor_infer.start_server(data=DataProto(meta_info={"global_step": global_step, "is_offload_states": False})) + else: + self.actor_infer.start_server(data=DataProto(meta_info={"global_step": global_step, "is_offload_states": True})) + + batch: DataProto = DataProto() + batch.meta_info = {"global_step": global_step} + + if self.pipeline_config.eval_steps > 0 and global_step % self.pipeline_config.eval_steps == 0: + with Timer(name="val", logger=None) as val_timer: + metrics.update(self.val(global_step=global_step)) + metrics["time/step_val"] = val_timer.last + + with Timer(name="rollout", logger=None) as rollout_timer: + batch.meta_info["is_offload_states"] = True + batch = ray.get(self.train_rollout_scheduler.get_batch.remote(batch, self.pipeline_config.rollout_batch_size)) + if "get_batch_return_start_time" in batch.meta_info: + metrics["time/get_batch_cost_train"] = time.time() - batch.meta_info.pop("get_batch_return_start_time") + actor_infer_metrics = self.actor_infer.get_metrics() + metrics.update(reduce_metrics(actor_infer_metrics.meta_info.pop("metrics", {}))) + + dump_rollout_trajectories(self.pipeline_config.rollout_dump_dir, global_step, batch) + + metrics["time/step_rollout"] = rollout_timer.last metrics.update(reduce_metrics(batch.meta_info.pop("metrics", {}))) - metrics.update(reward_metrics) - metrics["time/cal_norm_rewards"] = timer.last - - with Timer(name="cal_token_reward", logger=None) as timer: - # Expand compute_response_level_rewards and add kl_penalty. - # batch, kl_metrics = apply_kl_penalty(data=batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.pipeline_config.kl_penalty) - batch, token_level_metrics = compute_token_reward(batch, self.pipeline_config, self.kl_ctrl) - metrics.update(token_level_metrics) - metrics["time/cal_token_reward"] = timer.last - - with Timer(name="compute_advantage", logger=None) as timer: - # Is the advantage calculated globally across the batch, or within each group? 
- batch = agentic_compute_advantage( - data=batch, - gamma=self.pipeline_config.gamma, - lambd=self.pipeline_config.lambd, - adv_estimator=self.pipeline_config.adv_estimator, - advantage_clip=self.pipeline_config.advantage_clip, - whiten_advantages=self.pipeline_config.whiten_advantages, - whiten_rewards=self.pipeline_config.whiten_rewards, - ) - metrics.update(reduce_metrics(batch.meta_info.pop("metrics", {}))) - metrics["time/adv"] = timer.last - - if self.pipeline_config.adv_estimator == "gae": - critic_train_metrics_refs: List[ray.ObjectRef] = self.critic.train_step(batch, blocking=False) - - # implement critic warmup - if self.pipeline_config.critic_warmup <= global_step: - # update actor - actor_train_metrics_refs = self.actor_train.train_step(batch, blocking=False) - actor_train_metrics: DataProto = DataProto.materialize_concat(data_refs=actor_train_metrics_refs) - metrics.update(reduce_metrics(actor_train_metrics.meta_info.pop("metrics", {}))) - - if self.pipeline_config.adv_estimator == "gae": - critic_train_metrics = DataProto.materialize_concat(data_refs=critic_train_metrics_refs) - metrics.update(reduce_metrics(critic_train_metrics.meta_info.pop("metrics", {}))) - tps_timer.push_units_processed(n=torch.sum(batch.batch["attention_mask"]).detach().item()) - - data_metrics = compute_data_metrics(batch=batch) - metrics.update(data_metrics) - metrics["system/tps"] = tps_timer.mean_throughput - metrics["system/samples"] = (global_step + 1) * self.pipeline_config.rollout_batch_size + batch.meta_info["global_step"] = global_step + if not (self.pipeline_config.async_generation_ratio > 0): + self.actor_infer.stop_server() - # do ckpt - self.state.step = global_step - self.state.log_history.append(metrics) + batch = compute_discounted_returns(batch, self.pipeline_config.adv_estimator, self.pipeline_config.step_reward_gamma) - self.do_checkpoint(global_step=global_step) - - self.tracker.log(values=metrics, step=global_step) + batch = self.adjust_batch(batch, mode=self.pipeline_config.batch_adjust_mode) + metrics.update(reduce_metrics(batch.meta_info.pop("metrics", {}))) - if global_step % self.pipeline_config.logging_steps == 0: - if int(os.environ.get("RAY_PROFILING", "0")): - timeline_dir = os.path.join(self.pipeline_config.profiler_output_dir, "timeline") - os.makedirs(timeline_dir, exist_ok=True) - ray.timeline( - filename=os.path.join(timeline_dir, f"timeline-step-{global_step}.json"), - ) - - log_res = [] - batch_grouped = batch.group_by(keys="traj_id") - for group_name, group_batch in batch_grouped.items(): - prompt_mask = group_batch.batch["prompt_mask"] - non_prompt_mask = torch.logical_not(group_batch.batch["prompt_mask"]) * group_batch.batch["attention_mask"] - input_ids = group_batch.batch["input_ids"] - prompt_ids_list = [input_ids[i][mask.bool()] for i, mask in enumerate(prompt_mask)] - response_ids_list = [input_ids[i][mask.bool()] for i, mask in enumerate(non_prompt_mask)] - prompts = self.tokenizer.batch_decode(prompt_ids_list, skip_special_tokens=False) - responses = self.tokenizer.batch_decode(response_ids_list, skip_special_tokens=False) - episode_scores = group_batch.non_tensor_batch["episode_scores"].tolist() - step_scores = group_batch.non_tensor_batch["step_scores"].tolist() - if not isinstance(step_scores[0], float): - step_scores = [t.tolist() for t in step_scores] - - log_item = [] - for prompt, response, episode_score, step_score in zip( - prompts, responses, episode_scores, step_scores - ): - log_item.append( - { - "prompt": prompt, - "response": response, - 
"episode_score": episode_score, - "step_score": step_score, - } + with Timer(name="cal_ref_log_probs", logger=None) as cal_timer: + ref_log_probs_refs: List[ray.ObjectRef] = self.reference.compute_log_probs(batch, blocking=False) + ref_log_probs = DataProto.materialize_concat(data_refs=ref_log_probs_refs) + ref_log_probs.rename(old_keys="log_probs", new_keys="ref_log_probs") + batch = batch.union(ref_log_probs) + avg_ref_log_prob = masked_mean(batch.batch["ref_log_probs"], batch.batch["response_mask"][:, 1:]) + metrics.update(reduce_metrics(ref_log_probs.meta_info.pop("metrics", {}))) + metrics.update({"critic/ref_log_prob/mean": avg_ref_log_prob.item()}) + metrics["time/step_ref_log_probs_values_reward"] = cal_timer.last + + with Timer(name="cal_old_log_probs_values", logger=None) as cal_old_logpb_timer: + # TODO: use engine log_probs as old_log_probs + batch.meta_info["is_offload_states"] = False + old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs(batch, blocking=False) + if self.pipeline_config.adv_estimator == "gae": + values_refs: List[ray.ObjectRef] = self.critic.compute_values(batch, blocking=False) + old_log_probs = DataProto.materialize_concat(data_refs=old_log_probs_refs) + if self.pipeline_config.adv_estimator == "gae": + values = DataProto.materialize_concat(data_refs=values_refs) + batch = batch.union(values) + metrics.update(reduce_metrics(values.meta_info.pop("metrics", {}))) + batch.batch["old_log_probs"] = old_log_probs.batch["log_probs"] + avg_old_log_prob = masked_mean(batch.batch["old_log_probs"], batch.batch["response_mask"][:, 1:]) + metrics.update({"critic/old_log_prob/mean": avg_old_log_prob.item()}) + + agg_entropy = agg_loss( + loss_mat=old_log_probs.batch["entropy"], + loss_mask=batch.batch["response_mask"][:, 1:], + loss_agg_mode="token-mean", ) - log_res.append(log_item) - if len(log_res) >= 10: - break - logger.info(json.dumps(log_res, ensure_ascii=False)) - logger.info(json.dumps(metrics, ensure_ascii=False)) + metrics.update({"critic/entropy/mean": agg_entropy.item()}) + + metrics.update(reduce_metrics(old_log_probs.meta_info.pop("metrics", {}))) + metrics["time/step_old_log_probs_values"] = cal_old_logpb_timer.last + + # TODO 当前这个还没用处 + with Timer(name="cal_response_level_mask", logger=None) as timer: + # TODO 补充完善的过滤要求,不同环境需要维持统一过滤标识 + batch, mask_metrics = get_agentic_response_level_mask(batch, self.pipeline_config) + metrics.update(mask_metrics) + metrics["time/step_cal_response_level_mask"] = timer.last + + with Timer(name="cal_response_norm_rewards", logger=None) as timer: + # Rewards need to be processed after grouping + # We can group by tag(env_type)/traj_group_id(group)/batch(rollout_batch)... to compute rewards / advantages + # The compute_response_level_rewards function injects a response_level_rewards key into batch.batch. + batch, reward_metrics = compute_response_level_rewards(batch=batch, pipeline_config=self.pipeline_config) + metrics.update(reduce_metrics(batch.meta_info.pop("metrics", {}))) + metrics.update(reward_metrics) + metrics["time/step_cal_norm_rewards"] = timer.last + + with Timer(name="cal_token_reward", logger=None) as timer: + # Expand compute_response_level_rewards and add kl_penalty. 
+ # batch, kl_metrics = apply_kl_penalty(data=batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.pipeline_config.kl_penalty) + batch, token_level_metrics = compute_token_reward(batch, self.pipeline_config, self.kl_ctrl) + metrics.update(token_level_metrics) + metrics["time/step_cal_token_reward"] = timer.last + + with Timer(name="compute_advantage", logger=None) as timer: + # Is the advantage calculated globally across the batch, or within each group? + batch = agentic_compute_advantage( + data=batch, + gamma=self.pipeline_config.gamma, + lambd=self.pipeline_config.lambd, + adv_estimator=self.pipeline_config.adv_estimator, + advantage_clip=self.pipeline_config.advantage_clip, + whiten_advantages=self.pipeline_config.whiten_advantages, + whiten_rewards=self.pipeline_config.whiten_rewards, + ) + metrics.update(reduce_metrics(batch.meta_info.pop("metrics", {}))) + metrics["time/step_adv"] = timer.last + + with Timer(name="train_timer", logger=None) as train_timer: + if self.pipeline_config.adv_estimator == "gae": + critic_train_metrics_refs: List[ray.ObjectRef] = self.critic.train_step(batch, blocking=False) + + # implement critic warmup + if self.pipeline_config.critic_warmup <= global_step: + # update actor + actor_train_metrics_refs = self.actor_train.train_step(batch, blocking=False) + actor_train_metrics: DataProto = DataProto.materialize_concat(data_refs=actor_train_metrics_refs) + metrics.update(reduce_metrics(actor_train_metrics.meta_info.pop("metrics", {}))) + + if self.pipeline_config.adv_estimator == "gae": + critic_train_metrics = DataProto.materialize_concat(data_refs=critic_train_metrics_refs) + metrics.update(reduce_metrics(critic_train_metrics.meta_info.pop("metrics", {}))) + tps_timer.push_units_processed(n=torch.sum(batch.batch["attention_mask"]).detach().item()) + metrics["time/step_train"] = train_timer.last + + with Timer(name="compute_data_metrics", logger=None) as data_metrics_timer: + data_metrics = compute_data_metrics(batch=batch) + + metrics["time/step_compute_data_metrics"] = data_metrics_timer.last + metrics.update(data_metrics) + metrics["system/tps"] = tps_timer.mean_throughput + metrics["system/samples"] = (global_step + 1) * self.pipeline_config.rollout_batch_size + + # do ckpt + self.state.step = global_step + self.state.log_history.append(metrics) + + self.do_checkpoint(global_step=global_step) + + with Timer(name="log", logger=None) as log_timer: + if self.pipeline_config.logging_steps > 0 and global_step % self.pipeline_config.logging_steps == 0: + if int(os.environ.get("RAY_PROFILING", "0")): + timeline_dir = os.path.join(self.pipeline_config.profiler_output_dir, "timeline") + os.makedirs(timeline_dir, exist_ok=True) + ray.timeline( + filename=os.path.join(timeline_dir, f"timeline-step-{global_step}.json"), + ) + + log_res = [] + batch_grouped = batch.group_by(keys="traj_id") + for group_name, group_batch in batch_grouped.items(): + prompt_mask = group_batch.batch["prompt_mask"] + non_prompt_mask = torch.logical_not(group_batch.batch["prompt_mask"]) * group_batch.batch["attention_mask"] + input_ids = group_batch.batch["input_ids"] + prompt_ids_list = [input_ids[i][mask.bool()] for i, mask in enumerate(prompt_mask)] + response_ids_list = [input_ids[i][mask.bool()] for i, mask in enumerate(non_prompt_mask)] + prompts = self.tokenizer.batch_decode(prompt_ids_list, skip_special_tokens=False) + responses = self.tokenizer.batch_decode(response_ids_list, skip_special_tokens=False) + episode_scores = group_batch.non_tensor_batch["episode_scores"].tolist() + step_scores 
= group_batch.non_tensor_batch["step_scores"].tolist() + if not isinstance(step_scores[0], float): + step_scores = [t.tolist() for t in step_scores] + + log_item = [] + for prompt, response, episode_score, step_score in zip( + prompts, responses, episode_scores, step_scores + ): + log_item.append( + { + "prompt": prompt, + "response": response, + "episode_score": episode_score, + "step_score": step_score, + } + ) + log_res.append(log_item) + if len(log_res) >= 10: + break + logger.info(json.dumps(log_res, ensure_ascii=False)) + logger.info(json.dumps(metrics, ensure_ascii=False)) + + metrics["time/step_log"] = log_timer.last + + metrics["time/step_total"] = step_timer.last + self.tracker.log(values=metrics, step=global_step) logger.info(f"pipeline step {global_step} finished") global_step += 1 @@ -332,6 +356,9 @@ def val(self, global_step): ray.get(self.val_dataset_manager.reset.remote()) eval_batch = ray.get(self.val_rollout_scheduler.get_batch.remote(batch, self.pipeline_config.val_batch_size)) + if "get_batch_return_start_time" in eval_batch.meta_info: + metrics["time/get_batch_cost_val"] = time.time() - eval_batch.meta_info.pop("get_batch_return_start_time") + dump_rollout_trajectories(self.pipeline_config.rollout_dump_dir, global_step, eval_batch) eval_metrics = reduce_metrics(eval_batch.meta_info.get("metrics", {})) eval_score = get_episode_scores(eval_batch) diff --git a/roll/pipeline/agentic/agentic_rollout_pipeline.py b/roll/pipeline/agentic/agentic_rollout_pipeline.py index d2c875923..0586d3ec9 100644 --- a/roll/pipeline/agentic/agentic_rollout_pipeline.py +++ b/roll/pipeline/agentic/agentic_rollout_pipeline.py @@ -1,5 +1,6 @@ import json import os.path +import time from itertools import count from typing import Any @@ -70,7 +71,12 @@ def run(self): if batch is None: break - metrics["time/rollout"] = rollout_timer.last + if "get_batch_return_start_time" in batch.meta_info: + metrics["time/get_batch_cost_train"] = time.time() - batch.meta_info.pop("get_batch_return_start_time") + actor_infer_metrics: DataProto = self.actor_infer.get_metrics() + metrics.update(reduce_metrics(actor_infer_metrics.meta_info.pop("metrics", {}))) + + metrics["time/step_rollout"] = rollout_timer.last eval_metrics = reduce_metrics(batch.meta_info.get("metrics", {})) eval_score = batch.batch["scores"].sum(-1) eval_metrics["score/mean"] = torch.mean(eval_score).detach().item() diff --git a/roll/pipeline/agentic/environment_worker.py b/roll/pipeline/agentic/environment_worker.py index bbc12649b..bd4ede7b6 100644 --- a/roll/pipeline/agentic/environment_worker.py +++ b/roll/pipeline/agentic/environment_worker.py @@ -1,5 +1,6 @@ import asyncio import copy +import os import threading from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict, Optional @@ -7,6 +8,8 @@ from codetiming import Timer from transformers import PreTrainedTokenizer, ProcessorMixin +from roll.utils.context_managers import local_profiler + from roll.pipeline.agentic.env_manager.base_env_manager import BaseEnvManager from roll.distributed.executor.worker import Worker from roll.distributed.scheduler.decorator import Dispatch, register @@ -36,6 +39,7 @@ def __init__(self, worker_config: EnvManagerConfig): self.env_configs: Dict[int, Dict] = worker_config.env_configs[self.rank] self.thread_lock = threading.Lock() self.output_queue = None + self.mode = "train" @register(dispatch_mode=Dispatch.ONE_TO_ALL, clear_cache=False) async def initialize(self, @@ -47,6 +51,7 @@ async def initialize(self, 
super().initialize(pipeline_config) self.output_queue = output_queue + self.mode = mode model_name_or_path = download_model(self.worker_config.model_args.model_name_or_path) self.tokenizer = default_tokenizer_provider(self.worker_config.model_args, model_name_or_path) self.processor = default_processor_provider(self.worker_config.model_args, model_name_or_path) @@ -88,14 +93,29 @@ def create_env_manager(env_id, env_config): @register(dispatch_mode=Dispatch.ONE_TO_ALL, clear_cache=False) async def run_rollout_loop(self, seed): + # Set environment variables for profiler context + os.environ["roll_EXEC_FUNC_NAME"] = "run_rollout_loop" + os.environ["WORKER_NAME"] = f"EnvironmentWorker_{self.rank}" + loop = asyncio.get_event_loop() pool = ThreadPoolExecutor(max_workers=len(self.env_managers)) - await asyncio.gather( - *[ - loop.run_in_executor(pool, env_manager.run_rollout_loop, DataProto(meta_info={"seed": seed})) - for env_manager in self.env_managers.values() - ] - ) + + def run_with_profiler(env_manager, data_proto): + with local_profiler(): + return env_manager.run_rollout_loop(data_proto) + + def run_without_profiler(env_manager, data_proto): + return env_manager.run_rollout_loop(data_proto) + + tasks = [] + for env_id, env_manager in self.env_managers.items(): + # Only profile the first env_manager (env_id=0) on rank=0 + run_func = run_without_profiler + if self.rank == 0 and env_id == 0: + run_func = run_with_profiler + tasks.append(loop.run_in_executor(pool, run_func, env_manager, DataProto(meta_info={"seed": seed}))) + + await asyncio.gather(*tasks) pool.shutdown() @register(dispatch_mode=Dispatch.ONE_TO_ALL, clear_cache=False) diff --git a/roll/utils/functionals.py b/roll/utils/functionals.py index 4097ce30c..e02876567 100644 --- a/roll/utils/functionals.py +++ b/roll/utils/functionals.py @@ -350,8 +350,55 @@ def response_level_masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift def reduce_metrics(metrics: dict, reduce_func=np.mean) -> dict: + """ + Reduce metrics with enhanced aggregation support based on metric name suffixes. 
+ + Supported suffixes: + - _mean: arithmetic mean (default) + - _max: maximum value + - _min: minimum value + - _p50: 50th percentile (median) + - _p99: 99th percentile + - _std: standard deviation + - _sum: sum of all values + + Args: + metrics: Dictionary of metric names to lists/tensors of values + reduce_func: Default reduction function (used for metrics without suffix) + + Returns: + Dictionary with reduced metric values + """ + import numpy as np + + def _parse_suffix(metric_name): + """Parse aggregation method from metric name suffix.""" + if metric_name.endswith('_mean'): + return np.mean + elif metric_name.endswith('_max'): + return np.max + elif metric_name.endswith('_min'): + return np.min + elif metric_name.endswith('_p50'): + return lambda x: np.percentile(x, 50) + elif metric_name.endswith('_p99'): + return lambda x: np.percentile(x, 99) + elif metric_name.endswith('_std'): + return np.std + elif metric_name.endswith('_sum'): + return np.sum + else: + return reduce_func + for key, val in metrics.items(): - metrics[key] = reduce_func(val) + if isinstance(val, (list, tuple, np.ndarray)) and len(val) > 0: + # Use suffix-based aggregation if available + aggregation_func = _parse_suffix(key) + metrics[key] = float(aggregation_func(val)) + else: + # Fallback to default reduction function + metrics[key] = reduce_func(val) + return metrics diff --git a/roll/utils/import_utils.py b/roll/utils/import_utils.py index 1d8a6beba..c8b0e1d65 100644 --- a/roll/utils/import_utils.py +++ b/roll/utils/import_utils.py @@ -1,6 +1,7 @@ import importlib from importlib.util import find_spec from typing import Any, Optional +import traceback from roll.utils.logging import get_logger @@ -20,6 +21,7 @@ def can_import_class(class_path: str) -> bool: return True except Exception as e: logger.error(f"Failed to import class {class_path}: {e}") + logger.error(f"Full traceback: {traceback.format_exc()}") return False @@ -30,4 +32,4 @@ def safe_import_class(class_path: str) -> Optional[Any]: cls = getattr(module, class_name) return cls else: - return None + return None \ No newline at end of file diff --git a/roll/utils/offload_states.py b/roll/utils/offload_states.py index af69bdf95..5bbf20aab 100644 --- a/roll/utils/offload_states.py +++ b/roll/utils/offload_states.py @@ -4,7 +4,6 @@ import torch from torch import Tensor from transformers import PreTrainedModel -from trl import AutoModelForCausalLMWithValueHead from roll.platforms import current_platform From 307924e5450768e72fd0dce73b5b0c95530ac33e Mon Sep 17 00:00:00 2001 From: "tianhe.lzd" Date: Tue, 11 Nov 2025 18:06:54 +0800 Subject: [PATCH 31/58] (feat): sglang 054 patch. 
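
The 0.5.4.post2 support follows the same pattern as the earlier sglang patches: subclass the components that need collective-group and weight-broadcast hooks (EngineSA, TokenizerManagerSA, SchedulerSA, TpModelWorkerSA, ModelRunnerSA) and swap them into `sys.modules` just before sglang constructs the originals. A minimal, self-contained sketch of that substitution pattern, with placeholder module and class names (only the `SchedulerSA` name is borrowed from this patch; nothing here is the real sglang API):

```python
import sys
import types

# Stand-in for an upstream module we cannot edit in place.
engine_mod = types.ModuleType("upstream.engine")

class Scheduler:
    def step(self) -> str:
        return "base scheduler step"

def launch_subprocesses() -> str:
    # Upstream resolves the class through its own module at call time,
    # which is what makes the swap below take effect.
    cls = sys.modules["upstream.engine"].__dict__["Scheduler"]
    return cls().step()

engine_mod.Scheduler = Scheduler
engine_mod.launch_subprocesses = launch_subprocesses
sys.modules["upstream.engine"] = engine_mod

# Patched subclass adding the extra collective/broadcast behaviour.
class SchedulerSA(Scheduler):
    def step(self) -> str:
        return "patched step with weight-broadcast hooks"

# Swap the symbol before the upstream entry point instantiates it.
sys.modules["upstream.engine"].__dict__["Scheduler"] = SchedulerSA
print(engine_mod.launch_subprocesses())  # -> patched step with weight-broadcast hooks
```

Because the swap happens on the module's `__dict__` rather than on a local import, every later lookup inside the upstream code path picks up the patched class without forking sglang itself.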
--- roll/third_party/sglang/__init__.py | 3 + .../third_party/sglang/v054_patch/__init__.py | 2 + roll/third_party/sglang/v054_patch/engine.py | 113 ++++++++++ .../sglang/v054_patch/model_runner.py | 200 ++++++++++++++++++ .../sglang/v054_patch/scheduler.py | 98 +++++++++ .../sglang/v054_patch/tokenizer_manager.py | 127 +++++++++++ .../sglang/v054_patch/tp_worker.py | 58 +++++ 7 files changed, 601 insertions(+) create mode 100644 roll/third_party/sglang/v054_patch/__init__.py create mode 100644 roll/third_party/sglang/v054_patch/engine.py create mode 100644 roll/third_party/sglang/v054_patch/model_runner.py create mode 100644 roll/third_party/sglang/v054_patch/scheduler.py create mode 100644 roll/third_party/sglang/v054_patch/tokenizer_manager.py create mode 100644 roll/third_party/sglang/v054_patch/tp_worker.py diff --git a/roll/third_party/sglang/__init__.py b/roll/third_party/sglang/__init__.py index 19c13ee31..e3d796903 100644 --- a/roll/third_party/sglang/__init__.py +++ b/roll/third_party/sglang/__init__.py @@ -13,5 +13,8 @@ elif sgl.__version__ == '0.5.2': from roll.third_party.sglang import v052_patch patch = v052_patch +elif sgl.__version__ == '0.5.4.post2': + from roll.third_party.sglang import v054_patch + patch = v054_patch else: raise NotImplementedError(f"Scale aligner version sglang:{sgl.__version__} is not supported.") \ No newline at end of file diff --git a/roll/third_party/sglang/v054_patch/__init__.py b/roll/third_party/sglang/v054_patch/__init__.py new file mode 100644 index 000000000..fa4bec152 --- /dev/null +++ b/roll/third_party/sglang/v054_patch/__init__.py @@ -0,0 +1,2 @@ +from . import engine +from . import scheduler \ No newline at end of file diff --git a/roll/third_party/sglang/v054_patch/engine.py b/roll/third_party/sglang/v054_patch/engine.py new file mode 100644 index 000000000..99489347c --- /dev/null +++ b/roll/third_party/sglang/v054_patch/engine.py @@ -0,0 +1,113 @@ +import asyncio +from sglang.srt.entrypoints.engine import Engine + +from roll.third_party.sglang.io_struct import ( + SetupCollectiveGroupReqInput, + BroadcastBucketReqInput, + BroadcastParameterReqInput, + UpdateParameterInBucketReqInput, + UpdateParameterReqInput, +) +import sglang.srt.entrypoints.engine as engine_module + + +class EngineSA(Engine): + + def setup_collective_group( + self, + comm_plan: str, + backend: str, + rank_in_cluster: int, + ): + obj = SetupCollectiveGroupReqInput( + comm_plan=comm_plan, + backend=backend, + rank_in_cluster=rank_in_cluster, + ) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.setup_collective_group(obj, None) + ) + + def broadcast_bucket( + self, + src_pp_rank: int, + meta_infos: dict, + bucket_size: int, + ): + obj = BroadcastBucketReqInput( + src_pp_rank=src_pp_rank, + meta_infos=meta_infos, + bucket_size=bucket_size, + ) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.broadcast_bucket(obj, None) + ) + + def broadcast_parameter( + self, + src_pp_rank, + dtype, + shape, + parameter_name + ): + obj = BroadcastParameterReqInput( + src_pp_rank=src_pp_rank, + dtype=dtype, + shape=shape, + parameter_name=parameter_name, + ) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.broadcast_parameter(obj, None) + ) + + def update_parameter( + self, + parameter_name, + weight, + ranks_in_worker + ): + obj = UpdateParameterReqInput( + parameter_name=parameter_name, + weight=weight, + ranks_in_worker=ranks_in_worker, + ) + loop = 
asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.update_parameter(obj, None) + ) + + def update_parameter_in_bucket( + self, + meta_infos, + buffer, + ranks_in_worker + ): + """Initialize parameter update group.""" + obj = UpdateParameterInBucketReqInput( + meta_infos=meta_infos, + buffer=buffer, + ranks_in_worker=ranks_in_worker, + ) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.update_parameter_in_bucket(obj, None) + ) + +class _roll_launch_subprocesses(object): + def __init__(self, _launch_subprocesses): + self._launch_subprocesses = _launch_subprocesses + + def __call__(self, *args, **kwargs): + import sys + from roll.third_party.sglang.v054_patch.tokenizer_manager import TokenizerManagerSA + from roll.third_party.sglang.v054_patch.scheduler import run_scheduler_process + + sys.modules['sglang.srt.entrypoints.engine'].__dict__['TokenizerManager'] = TokenizerManagerSA + sys.modules['sglang.srt.entrypoints.engine'].__dict__['run_scheduler_process'] = run_scheduler_process + return self._launch_subprocesses(*args, **kwargs) + + +engine_module._launch_subprocesses = _roll_launch_subprocesses(engine_module._launch_subprocesses) \ No newline at end of file diff --git a/roll/third_party/sglang/v054_patch/model_runner.py b/roll/third_party/sglang/v054_patch/model_runner.py new file mode 100644 index 000000000..ce1832d8d --- /dev/null +++ b/roll/third_party/sglang/v054_patch/model_runner.py @@ -0,0 +1,200 @@ +import logging +from dataclasses import dataclass +import torch +import torch.distributed as dist +import datetime + +from roll.platforms import current_platform + + +from sglang.srt.model_executor.model_runner import ModelRunner, UNBALANCED_MODEL_LOADING_TIMEOUT_S +from sglang.srt.configs.device_config import DeviceConfig +from sglang.srt.configs.load_config import LoadConfig +from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp +from sglang.srt.distributed import get_tp_group +from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer +from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state +from sglang.srt.model_loader import get_model +from sglang.srt.offloader import get_offloader + +from sglang.srt.utils import ( + get_available_gpu_memory, + monkey_patch_vllm_gguf_config, + set_cuda_arch, +) + +from roll.utils.collective import collective +from roll.utils.functionals import get_dist_info_from_comm_plan +from roll.platforms import current_platform + +logger = logging.getLogger(__name__) + + +class ModelRunnerSA(ModelRunner): + def load_model(self): + before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id) + logger.info( + f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" + ) + + # This can reduce thread conflicts and speed up weight loading. + if self.device != "cpu": + torch.set_num_threads(1) + if self.device == current_platform.device_type: + if current_platform.get_device_capability()[0] < 8: + logger.info( + "Compute capability below sm80. Use float16 due to lack of bfloat16 support." 
+ ) + self.server_args.dtype = "float16" + self.model_config.dtype = torch.float16 + if current_platform.get_device_capability()[1] < 5: + raise RuntimeError("SGLang only supports sm75 and above.") + + set_cuda_arch() + + # Prepare the model config + self.load_config = LoadConfig( + load_format=self.server_args.load_format, + download_dir=self.server_args.download_dir, + model_loader_extra_config=self.server_args.model_loader_extra_config, + ) + if self.device == "cpu": + self.model_config = adjust_config_with_unaligned_cpu_tp( + self.model_config, self.load_config, self.tp_size + ) + if self.server_args.load_format == "gguf": + monkey_patch_vllm_gguf_config() + + # Load the model + # Remove monkey_patch when linear.py quant remove dependencies with vllm + monkey_patch_vllm_parallel_state() + monkey_patch_isinstance_for_vllm_base_layer() + + self.model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=DeviceConfig(self.device), + ) + monkey_patch_vllm_parallel_state(reverse=True) + monkey_patch_isinstance_for_vllm_base_layer(reverse=True) + + get_offloader().post_init() + + if self.server_args.kv_cache_dtype == "fp8_e4m3": + if self.server_args.quantization_param_path is not None: + if callable(getattr(self.model, "load_kv_cache_scales", None)): + self.model.load_kv_cache_scales( + self.server_args.quantization_param_path + ) + logger.info( + "Loaded KV cache scaling factors from %s", + self.server_args.quantization_param_path, + ) + else: + raise RuntimeError( + "Using FP8 KV cache and scaling factors provided but " + "model %s does not support loading scaling factors.", + self.model.__class__, + ) + else: + logger.warning( + "Using FP8 KV cache but no scaling factors " + "provided. Defaulting to scaling factors of 1.0. " + "This may lead to less accurate results!" + ) + + # Parse other args + self.sliding_window_size = None + if hasattr(self.model, "get_attention_sliding_window_size"): + self.sliding_window_size = self.model.get_attention_sliding_window_size() + elif self.model_config.attention_chunk_size is not None: + self.sliding_window_size = self.model_config.attention_chunk_size + logger.info( + f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}" + ) + + self.dtype = self.model_config.dtype + + after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id) + self.weight_load_mem_usage = before_avail_memory - after_avail_memory + logger.info( + f"Load weight end. " + f"type={type(self.model).__name__}, " + f"dtype={self.dtype}, " + f"avail mem={after_avail_memory:.2f} GB, " + f"mem usage={self.weight_load_mem_usage:.2f} GB." + ) + + # Handle the case where some ranks do not finish loading. + try: + dist.monitored_barrier( + group=get_tp_group().cpu_group, + timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S), + wait_all_ranks=True, + ) + except RuntimeError: + raise ValueError( + f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node." 
+ ) from None + + def setup_collective_group(self, comm_plan, backend, rank_in_cluster): + self.model_update_comm_plan = getattr(self, "model_update_comm_plan", {}) + rank, comm_plan_args = get_dist_info_from_comm_plan(comm_plan, rank_in_cluster=rank_in_cluster, + rank_in_worker=dist.get_rank()) + if rank is None: + logger.info(f"no comm_plan found for rank {rank_in_cluster}/{dist.get_rank()}") + return True, "Succeeded to setup_collective_group." + + group_name = comm_plan_args["group_name"] + master_addr = comm_plan_args["master_addr"] + master_port = comm_plan_args["master_port"] + world_size = len(comm_plan_args["tgt_devices"]) + 1 + src_pp_rank = comm_plan_args["src_pp_rank"] + collective.init_collective_group(world_size, rank, backend=backend, group_name=group_name, + master_addr=master_addr, master_port=master_port) + # A small all_reduce for warmup. + collective.allreduce(torch.zeros(1).to(current_platform.device_type), group_name=group_name) + self.model_update_comm_plan[src_pp_rank] = dict(rank=rank, + world_size=world_size, + src_pp_rank=src_pp_rank, + group_name=group_name, + comm_plan=comm_plan, + comm_plan_args=comm_plan_args) + logger.info(f"warmup setup_collective_group: {group_name} rank: {rank} world_size: {world_size}") + return True, "Succeeded to setup_collective_group." + + def broadcast_bucket(self, src_pp_rank, meta_infos, bucket_size): + if src_pp_rank not in self.model_update_comm_plan: + return True, "Succeeded to broadcast_bucket." + + comm_plan = self.model_update_comm_plan[src_pp_rank] + buffer = torch.empty(bucket_size, dtype=torch.int8, device=current_platform.device_type) + collective.broadcast(tensor=buffer, src_rank=0, group_name=comm_plan["group_name"]) + self.update_parameter_in_bucket(meta_infos, buffer, [dist.get_rank()]) + return True, "Succeeded to broadcast_bucket." + + def broadcast_parameter(self, src_pp_rank, dtype, shape, parameter_name): + if src_pp_rank not in self.model_update_comm_plan: + return True, "Succeeded to broadcast_parameter." + comm_plan = self.model_update_comm_plan[src_pp_rank] + weight = torch.empty(shape, dtype=dtype, device=current_platform.device_type) + collective.broadcast(tensor=weight, src_rank=0, group_name=comm_plan["group_name"]) + self.update_parameter(parameter_name, weight, [dist.get_rank()]) + return True, "Succeeded to broadcast_parameter." + + def update_parameter(self, parameter_name, weight, ranks_in_worker): + if dist.get_rank() not in ranks_in_worker: + return True, "Succeeded to update_parameter." + self.model.load_weights([(parameter_name, weight)]) + del weight + return True, "Succeeded to update_parameter." + + def update_parameter_in_bucket(self, meta_infos, buffer, ranks_in_worker): + if dist.get_rank() not in ranks_in_worker: + return True, "Succeeded to update_parameter_in_bucket." + from mcore_adapter.models.converter.convert_utils import RecvBucketManager + self.recv_manager = getattr(self, "recv_manager", RecvBucketManager()) + named_params = self.recv_manager.process_bucket(meta_infos, buffer) + del buffer + self.model.load_weights([(name, weight) for name, weight in named_params.items()]) + return True, "Succeeded to update_parameter_in_bucket." 
\ No newline at end of file diff --git a/roll/third_party/sglang/v054_patch/scheduler.py b/roll/third_party/sglang/v054_patch/scheduler.py new file mode 100644 index 000000000..bc8eccaba --- /dev/null +++ b/roll/third_party/sglang/v054_patch/scheduler.py @@ -0,0 +1,98 @@ +import torch +from roll.platforms import current_platform + + +from sglang.srt.managers.io_struct import ( + ReleaseMemoryOccupationReqInput, + ReleaseMemoryOccupationReqOutput, + ResumeMemoryOccupationReqOutput, + ResumeMemoryOccupationReqInput, +) +from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE +from sglang.srt.managers.scheduler import Scheduler + +from sglang.srt.managers.scheduler_update_weights_mixin import _import_static_state, _export_static_state + + +from roll.third_party.sglang.io_struct import ( + SetupCollectiveGroupReqInput, + BroadcastBucketReqInput, + BroadcastParameterReqInput, + UpdateParameterInBucketReqInput, + UpdateParameterReqInput, + SetupCollectiveGroupReqOutput, + BroadcastBucketReqOutput, + BroadcastParameterReqOutput, + UpdateParameterInBucketReqOutput, + UpdateParameterReqOutput, +) + +class SchedulerSA(Scheduler): + def __init__(self, *args, **kwargs): + import sys + from roll.third_party.sglang.v054_patch.tp_worker import TpModelWorkerSA + sys.modules['sglang.srt.managers.scheduler'].__dict__['TpModelWorker'] = TpModelWorkerSA + super().__init__(*args, **kwargs) + func_map_patch = [(SetupCollectiveGroupReqInput, self.setup_collective_group), + (BroadcastBucketReqInput, self.broadcast_bucket), + (BroadcastParameterReqInput, self.broadcast_parameter), + (UpdateParameterInBucketReqInput, self.update_parameter_in_bucket), + (UpdateParameterReqInput, self.update_parameter)] + self._request_dispatcher._mapping += func_map_patch + + def setup_collective_group(self, recv_req: SetupCollectiveGroupReqInput): + success, message = self.tp_worker.setup_collective_group(recv_req) + return SetupCollectiveGroupReqOutput(success, message) + + def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput): + self.stashed_model_static_state = _export_static_state( + self.tp_worker.model_runner.model + ) + self.tp_worker.model_runner.model.to('cpu') + self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE) + self.flush_cache() + return ReleaseMemoryOccupationReqOutput() + + def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput): + self.tp_worker.model_runner.model.to(current_platform.current_device()) + self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE) + + # gc.collect() + # torch.cuda.empty_cache() + # self.tp_worker.model_runner.model.to(current_platform.current_device()) + _import_static_state( + self.tp_worker.model_runner.model, self.stashed_model_static_state + ) + del self.stashed_model_static_state + + self.tp_worker.model_runner.init_cublas() + self.tp_worker.model_runner.init_attention_backend() + from sglang.srt.model_executor.cuda_graph_runner import set_global_graph_memory_pool + set_global_graph_memory_pool(None) + self.tp_worker.model_runner.init_device_graphs() + + return ResumeMemoryOccupationReqOutput() + + def broadcast_bucket(self, recv_req: BroadcastBucketReqInput): + success, message = self.tp_worker.broadcast_bucket(recv_req) + return BroadcastBucketReqOutput(success, message) + + def broadcast_parameter(self, recv_req: BroadcastParameterReqInput): + success, message = self.tp_worker.broadcast_parameter(recv_req) + return BroadcastParameterReqOutput(success, message) + + def update_parameter(self, recv_req: 
UpdateParameterReqInput): + success, message = self.tp_worker.update_parameter(recv_req) + return UpdateParameterReqOutput(success, message) + + def update_parameter_in_bucket(self, recv_req: UpdateParameterInBucketReqInput): + success, message = self.tp_worker.update_parameter_in_bucket(recv_req) + return UpdateParameterInBucketReqOutput(success, message) + + +def run_scheduler_process(*args, **kwargs): + import sys + sys.modules['sglang.srt.managers.scheduler'].__dict__['Scheduler'] = SchedulerSA + from sglang.srt.managers.scheduler import run_scheduler_process + return run_scheduler_process(*args, **kwargs) + diff --git a/roll/third_party/sglang/v054_patch/tokenizer_manager.py b/roll/third_party/sglang/v054_patch/tokenizer_manager.py new file mode 100644 index 000000000..c3708bbea --- /dev/null +++ b/roll/third_party/sglang/v054_patch/tokenizer_manager.py @@ -0,0 +1,127 @@ +import os +from typing import Optional, Tuple +import fastapi + +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator + +from roll.third_party.sglang.io_struct import ( + SetupCollectiveGroupReqInput, + BroadcastBucketReqInput, + BroadcastParameterReqInput, + UpdateParameterInBucketReqInput, + UpdateParameterReqInput, + SetupCollectiveGroupReqOutput, + BroadcastBucketReqOutput, + BroadcastParameterReqOutput, + UpdateParameterInBucketReqOutput, + UpdateParameterReqOutput, +) + +class TokenizerManagerSA(TokenizerManager): + def __init__( + self, + server_args: ServerArgs, + port_args: PortArgs, + ): + super().__init__(server_args=server_args, port_args=port_args) + + self.setup_collective_group_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.broadcast_bucket_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.broadcast_parameter_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.update_parameter_in_bucket_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.update_parameter_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + + communicator_patch = [( + SetupCollectiveGroupReqOutput, + self.setup_collective_group_communicator.handle_recv, + ), + ( + BroadcastBucketReqOutput, + self.broadcast_bucket_communicator.handle_recv, + ), + ( + BroadcastParameterReqOutput, + self.broadcast_parameter_communicator.handle_recv, + ), + ( + UpdateParameterInBucketReqOutput, + self.update_parameter_in_bucket_communicator.handle_recv, + ), + ( + UpdateParameterReqOutput, + self.update_parameter_communicator.handle_recv, + )] + + self._result_dispatcher._mapping += communicator_patch + + async def setup_collective_group( + self, + obj: SetupCollectiveGroupReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be 1 for init parameter update group" + result = (await self.setup_collective_group_communicator(obj))[0] + return result.success, result.message + + async def broadcast_bucket( + self, + obj: BroadcastBucketReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be 1 for init parameter update group" + result = (await self.broadcast_bucket_communicator(obj))[0] + return 
result.success, result.message + + async def broadcast_parameter( + self, + obj: BroadcastParameterReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be 1 for init parameter update group" + result = (await self.broadcast_parameter_communicator(obj))[0] + return result.success, result.message + + async def update_parameter( + self, + obj: UpdateParameterReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be 1 for init parameter update group" + result = (await self.update_parameter_communicator(obj))[0] + return result.success, result.message + + async def update_parameter_in_bucket( + self, + obj: UpdateParameterInBucketReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be 1 for init parameter update group" + result = (await self.update_parameter_in_bucket_communicator(obj))[0] + return result.success, result.message \ No newline at end of file diff --git a/roll/third_party/sglang/v054_patch/tp_worker.py b/roll/third_party/sglang/v054_patch/tp_worker.py new file mode 100644 index 000000000..eee8a8075 --- /dev/null +++ b/roll/third_party/sglang/v054_patch/tp_worker.py @@ -0,0 +1,58 @@ +from sglang.srt.managers.tp_worker import TpModelWorker + + +from roll.third_party.sglang.io_struct import ( + SetupCollectiveGroupReqInput, + BroadcastBucketReqInput, + BroadcastParameterReqInput, + UpdateParameterInBucketReqInput, + UpdateParameterReqInput, +) + +class TpModelWorkerSA(TpModelWorker): + def __init__(self, *args, **kwargs): + import sys + from roll.third_party.sglang.v054_patch.model_runner import ModelRunnerSA + sys.modules['sglang.srt.managers.tp_worker'].__dict__['ModelRunner'] = ModelRunnerSA + super().__init__(*args, **kwargs) + + def setup_collective_group(self, recv_req: SetupCollectiveGroupReqInput): + success, message = self.model_runner.setup_collective_group( + recv_req.comm_plan, + recv_req.backend, + recv_req.rank_in_cluster, + ) + return success, message + + def broadcast_bucket(self, recv_req: BroadcastBucketReqInput): + success, message = self.model_runner.broadcast_bucket( + recv_req.src_pp_rank, + recv_req.meta_infos, + recv_req.bucket_size, + ) + return success, message + + def broadcast_parameter(self, recv_req: BroadcastParameterReqInput): + success, message = self.model_runner.broadcast_parameter( + recv_req.src_pp_rank, + recv_req.dtype, + recv_req.shape, + recv_req.parameter_name, + ) + return success, message + + def update_parameter(self, recv_req: UpdateParameterReqInput): + success, message = self.model_runner.update_parameter( + recv_req.parameter_name, + recv_req.weight, + recv_req.ranks_in_worker, + ) + return success, message + + def update_parameter_in_bucket(self, recv_req: UpdateParameterInBucketReqInput): + success, message = self.model_runner.update_parameter_in_bucket( + recv_req.meta_infos, + recv_req.buffer, + recv_req.ranks_in_worker, + ) + return success, message \ No newline at end of file From bc0cd9d74d9d0990fc1545667c91ffcb8d9b0168 Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Tue, 11 Nov 2025 21:11:12 +0800 Subject: [PATCH 32/58] (feat): add enable_reference option. 
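
With this option the reference policy cluster becomes optional. `enable_reference`
defaults to false and is forced to true in `PPOConfig.__post_init__` whenever
`use_kl_loss` is set or `init_kl_coef > 0`. When it stays false, the pipelines skip
building and initializing the reference cluster and mock `ref_log_probs` from
`old_log_probs`, so downstream code that expects the key keeps working. A rough
sketch of the resulting control flow (illustrative only; names follow the pipeline
code touched below):

    if pipeline_config.enable_reference:
        ref_log_probs = reference.compute_log_probs(batch, blocking=True)
        ref_log_probs.rename(old_keys="log_probs", new_keys="ref_log_probs")
        batch = batch.union(ref_log_probs)
    else:
        # no reference cluster: reuse the training policy's own log-probs
        batch.batch["ref_log_probs"] = batch.batch["old_log_probs"].clone()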
--- roll/configs/base_config.py | 5 +++ roll/pipeline/agentic/agentic_pipeline.py | 41 +++++++++++------- roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py | 39 +++++++++-------- roll/pipeline/rlvr/rlvr_pipeline.py | 45 +++++++++++--------- roll/pipeline/rlvr/rlvr_vlm_pipeline.py | 32 ++++++++------ 5 files changed, 97 insertions(+), 65 deletions(-) diff --git a/roll/configs/base_config.py b/roll/configs/base_config.py index 8d9871e35..ca4d1f049 100644 --- a/roll/configs/base_config.py +++ b/roll/configs/base_config.py @@ -382,6 +382,9 @@ class PPOConfig(BaseConfig): field(default="seq-mean-token-mean", metadata={"help": "Loss aggregation mode"}) ) dual_clip_loss: bool = field(default=False, metadata={"help": "Use dual clip loss"}) + enable_reference: bool = field( + default=False, metadata={"help": "Whether to enable reference cluster for computing ref_log_probs."} + ) def __post_init__(self): super().__post_init__() @@ -406,6 +409,8 @@ def __post_init__(self): self.actor_train.name = "actor_train" self.reference.name = "reference" self.critic.name = "critic" + if self.use_kl_loss or self.init_kl_coef > 0: + self.enable_reference = True def set_max_steps(self, max_steps: int): actor_backward_batch_size = ( diff --git a/roll/pipeline/agentic/agentic_pipeline.py b/roll/pipeline/agentic/agentic_pipeline.py index 923658c07..62cf7b56d 100644 --- a/roll/pipeline/agentic/agentic_pipeline.py +++ b/roll/pipeline/agentic/agentic_pipeline.py @@ -62,14 +62,17 @@ def __init__(self, pipeline_config: AgenticConfig): resource_manager=self.resource_manager, worker_config=self.pipeline_config.actor_infer, ) - self.reference: Any = Cluster( - name=self.pipeline_config.reference.name, - worker_cls=self.pipeline_config.reference.worker_cls, - resource_manager=self.resource_manager, - worker_config=self.pipeline_config.reference, - ) + download_clusters = [self.actor_train, self.actor_infer] + + if self.pipeline_config.enable_reference: + self.reference: Any = Cluster( + name=self.pipeline_config.reference.name, + worker_cls=self.pipeline_config.reference.worker_cls, + resource_manager=self.resource_manager, + worker_config=self.pipeline_config.reference, + ) + download_clusters.append(self.reference) - download_clusters = [self.actor_train, self.actor_infer, self.reference] if self.pipeline_config.adv_estimator == "gae": self.critic: Any = Cluster( name=self.pipeline_config.critic.name, @@ -112,7 +115,8 @@ def __init__(self, pipeline_config: AgenticConfig): self.actor_infer.initialize(pipeline_config=self.pipeline_config, blocking=True) - refs.extend(self.reference.initialize(pipeline_config=self.pipeline_config, blocking=True)) + if self.pipeline_config.enable_reference: + refs.extend(self.reference.initialize(pipeline_config=self.pipeline_config, blocking=True)) self.set_model_update_pair( src_cluster=self.actor_train, tgt_cluster=self.actor_infer, @@ -189,13 +193,14 @@ def run(self): metrics.update(reduce_metrics(batch.meta_info.pop("metrics", {}))) with Timer(name="cal_ref_log_probs", logger=None) as cal_timer: - ref_log_probs_refs: List[ray.ObjectRef] = self.reference.compute_log_probs(batch, blocking=False) - ref_log_probs = DataProto.materialize_concat(data_refs=ref_log_probs_refs) - ref_log_probs.rename(old_keys="log_probs", new_keys="ref_log_probs") - batch = batch.union(ref_log_probs) - avg_ref_log_prob = masked_mean(batch.batch["ref_log_probs"], batch.batch["response_mask"][:, 1:]) - metrics.update(reduce_metrics(ref_log_probs.meta_info.pop("metrics", {}))) - metrics.update({"critic/ref_log_prob/mean": 
avg_ref_log_prob.item()}) + if self.pipeline_config.enable_reference: + ref_log_probs_refs: List[ray.ObjectRef] = self.reference.compute_log_probs(batch, blocking=False) + ref_log_probs = DataProto.materialize_concat(data_refs=ref_log_probs_refs) + ref_log_probs.rename(old_keys="log_probs", new_keys="ref_log_probs") + batch = batch.union(ref_log_probs) + avg_ref_log_prob = masked_mean(batch.batch["ref_log_probs"], batch.batch["response_mask"][:, 1:]) + metrics.update(reduce_metrics(ref_log_probs.meta_info.pop("metrics", {}))) + metrics.update({"critic/ref_log_prob/mean": avg_ref_log_prob.item()}) metrics["time/step_ref_log_probs_values_reward"] = cal_timer.last with Timer(name="cal_old_log_probs_values", logger=None) as cal_old_logpb_timer: @@ -213,6 +218,12 @@ def run(self): avg_old_log_prob = masked_mean(batch.batch["old_log_probs"], batch.batch["response_mask"][:, 1:]) metrics.update({"critic/old_log_prob/mean": avg_old_log_prob.item()}) + # Mock ref_log_probs using old_log_probs if reference cluster is disabled + if not self.pipeline_config.enable_reference: + batch.batch["ref_log_probs"] = batch.batch["old_log_probs"].clone() + avg_ref_log_prob = masked_mean(batch.batch["ref_log_probs"], batch.batch["response_mask"][:, 1:]) + metrics.update({"critic/ref_log_prob/mean": avg_ref_log_prob.item()}) + agg_entropy = agg_loss( loss_mat=old_log_probs.batch["entropy"], loss_mask=batch.batch["response_mask"][:, 1:], diff --git a/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py b/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py index bd812c1d0..21e919907 100644 --- a/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py @@ -222,7 +222,7 @@ def __init__(self, pipeline_config: RLVRConfig): worker_config=self.pipeline_config.actor_infer, ) # use unwrapped model as reference for lora training - if not self.is_lora: + if not self.is_lora and self.pipeline_config.enable_reference: self.reference: Any = Cluster( name=self.pipeline_config.reference.name, worker_cls=self.pipeline_config.reference.worker_cls, @@ -264,7 +264,7 @@ def __init__(self, pipeline_config: RLVRConfig): ray.get(refs) refs = [] - if not self.is_lora: + if not self.is_lora and self.pipeline_config.enable_reference: refs.extend(self.reference.initialize(pipeline_config=self.pipeline_config, blocking=False)) refs.extend(self.reward.initialize(pipeline_config=self.pipeline_config, blocking=False)) ray.get(refs) @@ -360,25 +360,24 @@ def run(self): ) with Timer(name="cal_ref_log_probs_reward", logger=None) as cal_timer: - if self.is_lora: - batch.meta_info["disable_adapter"] = True - batch.meta_info["is_offload_states"] = False - ref_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs( - batch, blocking=False - ) - else: - ref_log_probs_refs: List[ray.ObjectRef] = self.reference.compute_log_probs( - batch, blocking=False - ) + if self.pipeline_config.enable_reference: + if self.is_lora: + batch.meta_info["disable_adapter"] = True + batch.meta_info["is_offload_states"] = False + ref_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs( + batch, blocking=False + ) + else: + ref_log_probs_refs: List[ray.ObjectRef] = self.reference.compute_log_probs( + batch, blocking=False + ) + ref_log_probs = DataProto.materialize_concat(data_refs=ref_log_probs_refs) + metrics.update(reduce_metrics(ref_log_probs.meta_info.pop("metrics", {}))) + ref_log_probs.rename(old_keys="log_probs", new_keys="ref_log_probs") + batch = batch.union(ref_log_probs) rewards_refs: List[ray.ObjectRef] 
= self.reward.compute_rewards(batch, blocking=False) - - ref_log_probs = DataProto.materialize_concat(data_refs=ref_log_probs_refs) rewards = DataProto.materialize_concat(data_refs=rewards_refs) - - metrics.update(reduce_metrics(ref_log_probs.meta_info.pop("metrics", {}))) metrics.update(reduce_metrics(rewards.meta_info.pop("metrics", {}))) - ref_log_probs.rename(old_keys="log_probs", new_keys="ref_log_probs") - batch = batch.union(ref_log_probs) batch = batch.union(rewards) metrics["time/ref_log_probs_values_reward"] = cal_timer.last @@ -399,6 +398,10 @@ def run(self): batch.batch["old_log_probs"] = old_log_probs.batch["log_probs"] metrics.update(reduce_metrics(old_log_probs.meta_info.pop("metrics", {}))) + + # Mock ref_log_probs using old_log_probs if reference is disabled + if not self.pipeline_config.enable_reference: + batch.batch["ref_log_probs"] = batch.batch["old_log_probs"].clone() metrics["time/old_log_probs"] = cal_old_logpb_timer.last diff --git a/roll/pipeline/rlvr/rlvr_pipeline.py b/roll/pipeline/rlvr/rlvr_pipeline.py index d4f35b2c2..5565e9989 100644 --- a/roll/pipeline/rlvr/rlvr_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_pipeline.py @@ -216,7 +216,7 @@ def __init__(self, pipeline_config: RLVRConfig): ) download_clusters = [self.actor_train, self.actor_infer] # use unwrapped model as reference for lora training - if not self.is_lora: + if not self.is_lora and self.pipeline_config.enable_reference: self.reference: Any = Cluster( name=self.pipeline_config.reference.name, worker_cls=self.pipeline_config.reference.worker_cls, @@ -310,7 +310,7 @@ def __init__(self, pipeline_config: RLVRConfig): refs.extend(self.actor_infer.initialize(pipeline_config=self.pipeline_config, blocking=False)) ray.get(refs) - if not self.is_lora: + if not self.is_lora and self.pipeline_config.enable_reference: refs.extend(self.reference.initialize(pipeline_config=self.pipeline_config, blocking=True)) refs = [] @@ -540,24 +540,25 @@ def run(self): with Timer(name="cal_ref_log_probs", logger=None) as cal_ref_log_probs_timer: - if self.is_lora: - batch.meta_info["disable_adapter"] = True - batch.meta_info["is_offload_states"] = False - ref_log_probs = self.actor_train.compute_log_probs(batch, blocking=True) - else: - if self.pipeline_config.reference.use_dynamic_batching_in_infer: - batch, dynamic_batching_metrics = dynamic_batching_shard( - batch, - self.reference.dp_size, - self.pipeline_config.reference.max_tokens_per_microbatch_in_infer, - self.pipeline_config.reference.sequence_length_round_in_infer, - "reference/compute_log_probs", - ) - metrics_mgr.add_metrics(dynamic_batching_metrics) - ref_log_probs = self.reference.compute_log_probs(batch, blocking=True) - metrics_mgr.add_reduced_metrics(ref_log_probs.meta_info.pop("metrics", {})) - ref_log_probs.rename(old_keys="log_probs", new_keys="ref_log_probs") - batch = batch.union(ref_log_probs) + if self.pipeline_config.enable_reference: + if self.is_lora: + batch.meta_info["disable_adapter"] = True + batch.meta_info["is_offload_states"] = False + ref_log_probs = self.actor_train.compute_log_probs(batch, blocking=True) + else: + if self.pipeline_config.reference.use_dynamic_batching_in_infer: + batch, dynamic_batching_metrics = dynamic_batching_shard( + batch, + self.reference.dp_size, + self.pipeline_config.reference.max_tokens_per_microbatch_in_infer, + self.pipeline_config.reference.sequence_length_round_in_infer, + "reference/compute_log_probs", + ) + metrics_mgr.add_metrics(dynamic_batching_metrics) + ref_log_probs = 
self.reference.compute_log_probs(batch, blocking=True) + metrics_mgr.add_reduced_metrics(ref_log_probs.meta_info.pop("metrics", {})) + ref_log_probs.rename(old_keys="log_probs", new_keys="ref_log_probs") + batch = batch.union(ref_log_probs) metrics_mgr.add_metric("time/ref_log_probs_values", cal_ref_log_probs_timer.last) with Timer(name="cal_old_log_probs_values", logger=None) as cal_old_logpb_timer: @@ -601,6 +602,10 @@ def run(self): batch.batch["old_log_probs"] = old_log_probs.batch["log_probs"] metrics_mgr.add_reduced_metrics(old_log_probs.meta_info.pop("metrics", {})) + + # Mock ref_log_probs using old_log_probs if reference is disabled + if not self.pipeline_config.enable_reference: + batch.batch["ref_log_probs"] = batch.batch["old_log_probs"].clone() metrics_mgr.add_metric("time/old_log_probs", cal_old_logpb_timer.last) # 要按domain group by处理reward diff --git a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py index 0405e2b29..54df18895 100644 --- a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py @@ -313,13 +313,15 @@ def __init__(self, pipeline_config: RLVRConfig): resource_manager=self.resource_manager, worker_config=self.pipeline_config.actor_infer, ) - self.reference: Any = Cluster( - name=self.pipeline_config.reference.name, - worker_cls=self.pipeline_config.reference.worker_cls, - resource_manager=self.resource_manager, - worker_config=self.pipeline_config.reference, - ) - download_clusters = [self.actor_train, self.actor_infer, self.reference] + download_clusters = [self.actor_train, self.actor_infer] + if self.pipeline_config.enable_reference: + self.reference: Any = Cluster( + name=self.pipeline_config.reference.name, + worker_cls=self.pipeline_config.reference.worker_cls, + resource_manager=self.resource_manager, + worker_config=self.pipeline_config.reference, + ) + download_clusters.append(self.reference) if self.pipeline_config.adv_estimator == "gae": self.critic: Any = Cluster( name=self.pipeline_config.critic.name, @@ -434,7 +436,8 @@ def __init__(self, pipeline_config: RLVRConfig): refs.extend(self.actor_infer.initialize(pipeline_config=self.pipeline_config, blocking=False)) ray.get(refs) - refs.extend(self.reference.initialize(pipeline_config=self.pipeline_config, blocking=True)) + if self.pipeline_config.enable_reference: + refs.extend(self.reference.initialize(pipeline_config=self.pipeline_config, blocking=True)) refs = [] for key, cluster in self.rewards.items(): refs.extend(cluster.initialize(pipeline_config=self.pipeline_config, blocking=False)) @@ -542,10 +545,11 @@ def run(self): batch.meta_info["_broadcast_non_tensor_batch"]= True with Timer(name="cal_ref_log_probs", logger=None) as cal_ref_log_probs_timer: - ref_log_probs = self.reference.compute_log_probs(batch, blocking=True) - metrics_mgr.add_reduced_metrics(ref_log_probs.meta_info.pop("metrics", {})) - ref_log_probs.rename(old_keys="log_probs", new_keys="ref_log_probs") - batch = batch.union(ref_log_probs) + if self.pipeline_config.enable_reference: + ref_log_probs = self.reference.compute_log_probs(batch, blocking=True) + metrics_mgr.add_reduced_metrics(ref_log_probs.meta_info.pop("metrics", {})) + ref_log_probs.rename(old_keys="log_probs", new_keys="ref_log_probs") + batch = batch.union(ref_log_probs) metrics_mgr.add_metric("time/ref_log_probs_values", cal_ref_log_probs_timer.last) with Timer(name="cal_old_log_probs_values", logger=None) as cal_old_logpb_timer: @@ -568,6 +572,10 @@ def run(self): batch.batch["old_log_probs"] = 
old_log_probs.batch["log_probs"] metrics_mgr.add_reduced_metrics(old_log_probs.meta_info.pop("metrics", {})) + + # Mock ref_log_probs using old_log_probs if reference is disabled + if not self.pipeline_config.enable_reference: + batch.batch["ref_log_probs"] = batch.batch["old_log_probs"].clone() metrics_mgr.add_metric("time/old_log_probs", cal_old_logpb_timer.last) # group by domain to process reward From 31feaf6fdd841842ef5ef852a953a194f847a112 Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Wed, 12 Nov 2025 12:22:55 +0800 Subject: [PATCH 33/58] (fix): fix agentic reference. --- roll/pipeline/agentic/agentic_pipeline.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/roll/pipeline/agentic/agentic_pipeline.py b/roll/pipeline/agentic/agentic_pipeline.py index 62cf7b56d..2ebd94789 100644 --- a/roll/pipeline/agentic/agentic_pipeline.py +++ b/roll/pipeline/agentic/agentic_pipeline.py @@ -401,7 +401,10 @@ def adjust_batch(self, data: DataProto, mode="copy") -> DataProto: """ actor_train_train_bsz = self.pipeline_config.actor_train.training_args.per_device_train_batch_size * self.pipeline_config.actor_train.training_args.gradient_accumulation_steps * self.actor_train.dp_size actor_train_infer_bsz = self.pipeline_config.actor_train.infer_batch_size * self.actor_train.dp_size - ref_infer_bsz = self.pipeline_config.reference.infer_batch_size * self.reference.dp_size + + ref_infer_bsz = 1 + if hasattr(self, "reference"): + ref_infer_bsz = self.pipeline_config.reference.infer_batch_size * self.reference.dp_size critic_train_bsz = 1 critic_infer_bsz = 1 if self.pipeline_config.adv_estimator == "gae": From 7c2985863fa6e58e245082d67bc56f4b70a784db Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Fri, 5 Dec 2025 17:16:23 +0800 Subject: [PATCH 34/58] (feat): add flash-linear-attention. 
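
flash-linear-attention provides the linear-attention / gated-delta-net (GDN) kernels
used by models such as Qwen3-Next, so it is added to both the sglang and vllm
requirement sets. The sglang tokenizer helper also moved between releases, hence the
try/except import fallback in sglang_strategy.py; rlvr_rollout_pipeline.py is updated
to the current dataset preprocessing signatures at the same time. A rough pre-flight
check (illustrative only; the import name `fla` is an assumption, not part of this
diff):

    try:
        import fla  # flash-linear-attention package
    except ImportError as exc:
        raise RuntimeError(
            "flash-linear-attention is required for GDN models such as Qwen3-Next"
        ) from exc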
--- requirements_torch280_sglang.txt | 5 ++++- requirements_torch280_vllm.txt | 4 ++++ roll/distributed/strategy/sglang_strategy.py | 6 ++++- roll/pipeline/rlvr/rlvr_rollout_pipeline.py | 23 ++------------------ 4 files changed, 15 insertions(+), 23 deletions(-) diff --git a/requirements_torch280_sglang.txt b/requirements_torch280_sglang.txt index f42a343f4..e174bc6ac 100644 --- a/requirements_torch280_sglang.txt +++ b/requirements_torch280_sglang.txt @@ -5,4 +5,7 @@ torchvision==0.23.0.* torchaudio==2.8.0.* deepspeed==0.16.4 -sglang[srt,torch-memory-saver]==0.5.2 \ No newline at end of file +sglang[srt,torch-memory-saver]==0.5.2 + +# for GDN , eg: Qwen3Next +flash-linear-attention \ No newline at end of file diff --git a/requirements_torch280_vllm.txt b/requirements_torch280_vllm.txt index 0eec024f2..d01319cf6 100644 --- a/requirements_torch280_vllm.txt +++ b/requirements_torch280_vllm.txt @@ -11,3 +11,7 @@ flash-attn #todo upgrade docker image vllm==0.10.2 + + +# for GDN , eg: Qwen3Next +flash-linear-attention \ No newline at end of file diff --git a/roll/distributed/strategy/sglang_strategy.py b/roll/distributed/strategy/sglang_strategy.py index 38784a1b1..cdd46ce4d 100644 --- a/roll/distributed/strategy/sglang_strategy.py +++ b/roll/distributed/strategy/sglang_strategy.py @@ -10,7 +10,6 @@ import torch import torch.distributed as dist -from sglang.srt.hf_transformers_utils import get_tokenizer from torch.nn.utils.rnn import pad_sequence from transformers import set_seed @@ -24,6 +23,11 @@ from roll.utils.offload_states import OffloadStateType from roll.platforms import current_platform +try: + from sglang.srt.hf_transformers_utils import get_tokenizer +except: + from sglang.srt.utils.hf_transformers_utils import get_tokenizer + logger = get_logger() diff --git a/roll/pipeline/rlvr/rlvr_rollout_pipeline.py b/roll/pipeline/rlvr/rlvr_rollout_pipeline.py index d91b835cd..43d505a12 100644 --- a/roll/pipeline/rlvr/rlvr_rollout_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_rollout_pipeline.py @@ -1,9 +1,5 @@ import copy import json -import math -import os -import time -from datetime import datetime from functools import partial from typing import Any, Dict, List, Optional @@ -12,13 +8,9 @@ import torch from codetiming import Timer from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy -from ray.util.timer import _Timer -from roll.configs import GeneratingArguments -from roll.datasets.chat_template import get_chat_template from roll.datasets.collator import DataCollatorWithPaddingForPaddedKeys from roll.distributed.executor.cluster import Cluster -from roll.distributed.scheduler.async_generate_scheduler import AsyncDynamicSamplingScheduler from roll.distributed.scheduler.generate_scheduler import DynamicSamplingScheduler from roll.distributed.scheduler.protocol import DataProto from roll.models.model_providers import default_tokenizer_provider @@ -26,17 +18,6 @@ from roll.pipeline.rlvr.rlvr_config import RLVRConfig from roll.pipeline.rlvr.rlvr_pipeline import RLVRPipeline, get_encode_function, preprocess_dataset, \ update_dataset_domain -from roll.pipeline.rlvr.utils import dump_rollout_to_specific_path -from roll.utils.functionals import ( - RunningMoments, - agg_loss, - compute_advantage, - compute_token_reward, - get_sample_level_mask, - reduce_metrics, - reward_postprocess, -) -from roll.utils.kl_controller import get_kl_controller from roll.utils.logging import get_logger from roll.utils.metrics.metrics_manager import MetricsManager @@ -69,13 +50,13 @@ def __init__(self, 
pipeline_config: RLVRConfig): # 加上format,然后转ids的func template_name = self.pipeline_config.global_template - encode_function = get_encode_function(template_name, self.tokenizer) + encode_function = get_encode_function(template_name, self.tokenizer, self.pipeline_config.actor_train.data_args) self.val_dataset = preprocess_dataset( self.val_dataset, self.pipeline_config.prompt_length, encode_function, - num_proc=1, + data_args=self.pipeline_config.actor_train.data_args, ) self.val_dataset = self.val_dataset.map( partial(update_dataset_domain, self.pipeline_config.tag_2_domain), From c001d6cf5c723f8d2525fc9a95bdab4a3be14af6 Mon Sep 17 00:00:00 2001 From: "huangju.hj" Date: Thu, 13 Nov 2025 14:04:53 +0800 Subject: [PATCH 35/58] (fix): vllm _generate_standard missing prompt_token_ids input args in vllm >0.11.0. --- .../rlvr_config_8gpus.yaml | 10 ++++++++-- roll/distributed/strategy/vllm_strategy.py | 16 ++++------------ 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/examples/qwen2.5-7B-rlvr_megatron/rlvr_config_8gpus.yaml b/examples/qwen2.5-7B-rlvr_megatron/rlvr_config_8gpus.yaml index 9547179a9..91b1b342c 100644 --- a/examples/qwen2.5-7B-rlvr_megatron/rlvr_config_8gpus.yaml +++ b/examples/qwen2.5-7B-rlvr_megatron/rlvr_config_8gpus.yaml @@ -247,7 +247,13 @@ rewards: data_args: template: qwen2_5 strategy_args: - strategy_name: hf_infer - strategy_config: null + # strategy_name: hf_infer + # strategy_config: null + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.8 + block_size: 16 + max_model_len: 8000 + load_format: auto device_mapping: list(range(6,8)) infer_batch_size: 4 \ No newline at end of file diff --git a/roll/distributed/strategy/vllm_strategy.py b/roll/distributed/strategy/vllm_strategy.py index 1ecb925e7..21c5f6651 100644 --- a/roll/distributed/strategy/vllm_strategy.py +++ b/roll/distributed/strategy/vllm_strategy.py @@ -8,14 +8,14 @@ from collections import defaultdict, deque from concurrent import futures from typing import Dict, List, Optional, Union +from packaging.version import Version -import ray import torch import torch.distributed as dist from torch.nn.utils.rnn import pad_sequence from transformers import set_seed +import vllm from vllm import RequestOutput, SamplingParams -from vllm.beam_search import BeamSearchOutput from vllm.lora.request import LoRARequest from vllm.sampling_params import RequestOutputKind, BeamSearchParams from vllm.utils import random_uuid @@ -24,18 +24,10 @@ from roll.distributed.scheduler.protocol import DataProto, list_of_dict_to_dict_of_list from roll.distributed.strategy.strategy import InferenceStrategy from roll.third_party.vllm import LLM, AsyncLLM -from roll.utils.collective import collective from roll.utils.functionals import GenerateRequestType, concatenate_input_and_output, reduce_metrics from roll.utils.logging import get_logger from roll.utils.offload_states import OffloadStateType from roll.platforms import current_platform -try: - from vllm.inputs import TokensPrompt - high_version_vllm=True -except: - high_version_vllm=False - pass - logger = get_logger() @@ -166,11 +158,11 @@ def _generate_standard(self, batch: DataProto, generation_config) -> torch.Tenso if "multi_modal_data" in batch.non_tensor_batch: vllm_input_args["prompts"] = batch.non_tensor_batch["multi_modal_data"] else: - if high_version_vllm: + if Version(vllm.__version__) >= Version("0.11.0"): + from vllm.inputs import TokensPrompt prompt_token_ids_list=gather_unpadded_input_ids( input_ids=input_ids, attention_mask=attention_mask ) - 
vllm_input_args["prompts"] = [TokensPrompt(prompt_token_ids=prompt_token_ids)for prompt_token_ids in prompt_token_ids_list] else: vllm_input_args["prompt_token_ids"] = gather_unpadded_input_ids( From 85a081c99d7dab511f0ba8689315e2b8609677f5 Mon Sep 17 00:00:00 2001 From: "tianhe.lzd" Date: Thu, 13 Nov 2025 17:36:00 +0800 Subject: [PATCH 36/58] (fix): sglang 054post2 tp worker init wrong. --- .../sglang/v054_patch/model_runner.py | 96 ++++++++++++++----- .../sglang/v054_patch/scheduler.py | 2 +- 2 files changed, 72 insertions(+), 26 deletions(-) diff --git a/roll/third_party/sglang/v054_patch/model_runner.py b/roll/third_party/sglang/v054_patch/model_runner.py index ce1832d8d..12529a4c1 100644 --- a/roll/third_party/sglang/v054_patch/model_runner.py +++ b/roll/third_party/sglang/v054_patch/model_runner.py @@ -1,25 +1,31 @@ import logging -from dataclasses import dataclass import torch import torch.distributed as dist import datetime +import socket +import threading from roll.platforms import current_platform from sglang.srt.model_executor.model_runner import ModelRunner, UNBALANCED_MODEL_LOADING_TIMEOUT_S from sglang.srt.configs.device_config import DeviceConfig -from sglang.srt.configs.load_config import LoadConfig +from sglang.srt.configs.load_config import LoadConfig, LoadFormat + from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp from sglang.srt.distributed import get_tp_group -from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state from sglang.srt.model_loader import get_model -from sglang.srt.offloader import get_offloader +from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( + trigger_init_weights_send_group_for_remote_instance_request, +) +from sglang.srt.debug_utils.tensor_dump_forward_hook import ( + register_forward_hook_for_model, +) +from sglang.srt.utils.offloader import get_offloader from sglang.srt.utils import ( get_available_gpu_memory, - monkey_patch_vllm_gguf_config, set_cuda_arch, ) @@ -40,43 +46,68 @@ def load_model(self): # This can reduce thread conflicts and speed up weight loading. if self.device != "cpu": torch.set_num_threads(1) - if self.device == current_platform.device_type: - if current_platform.get_device_capability()[0] < 8: + if self.device == "cuda": + if torch.cuda.get_device_capability()[0] < 8: logger.info( "Compute capability below sm80. Use float16 due to lack of bfloat16 support." 
) self.server_args.dtype = "float16" self.model_config.dtype = torch.float16 - if current_platform.get_device_capability()[1] < 5: + if torch.cuda.get_device_capability()[1] < 5: raise RuntimeError("SGLang only supports sm75 and above.") set_cuda_arch() # Prepare the model config + from sglang.srt.configs.modelopt_config import ModelOptConfig + + modelopt_config = ModelOptConfig( + quant=self.server_args.modelopt_quant, + checkpoint_restore_path=self.server_args.modelopt_checkpoint_restore_path, + checkpoint_save_path=self.server_args.modelopt_checkpoint_save_path, + export_path=self.server_args.modelopt_export_path, + quantize_and_serve=self.server_args.quantize_and_serve, + ) + self.load_config = LoadConfig( load_format=self.server_args.load_format, download_dir=self.server_args.download_dir, model_loader_extra_config=self.server_args.model_loader_extra_config, + tp_rank=self.tp_rank, + remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip, + remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port, + remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports, + modelopt_config=modelopt_config, ) if self.device == "cpu": self.model_config = adjust_config_with_unaligned_cpu_tp( self.model_config, self.load_config, self.tp_size ) - if self.server_args.load_format == "gguf": - monkey_patch_vllm_gguf_config() + + if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE: + if self.tp_rank == 0: + instance_ip = socket.gethostbyname(socket.gethostname()) + t = threading.Thread( + target=trigger_init_weights_send_group_for_remote_instance_request, + args=( + self.server_args.remote_instance_weight_loader_seed_instance_ip, + self.server_args.remote_instance_weight_loader_seed_instance_service_port, + self.server_args.remote_instance_weight_loader_send_weights_group_ports, + instance_ip, + ), + ) + t.start() # Load the model # Remove monkey_patch when linear.py quant remove dependencies with vllm monkey_patch_vllm_parallel_state() - monkey_patch_isinstance_for_vllm_base_layer() self.model = get_model( model_config=self.model_config, load_config=self.load_config, - device_config=DeviceConfig(self.device), + device_config=DeviceConfig(self.device, self.gpu_id), ) monkey_patch_vllm_parallel_state(reverse=True) - monkey_patch_isinstance_for_vllm_base_layer(reverse=True) get_offloader().post_init() @@ -124,19 +155,34 @@ def load_model(self): f"avail mem={after_avail_memory:.2f} GB, " f"mem usage={self.weight_load_mem_usage:.2f} GB." ) - - # Handle the case where some ranks do not finish loading. - try: - dist.monitored_barrier( - group=get_tp_group().cpu_group, - timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S), - wait_all_ranks=True, + if self.server_args.debug_tensor_dump_output_folder is not None: + register_forward_hook_for_model( + self.model, + self.server_args.debug_tensor_dump_output_folder, + self.server_args.debug_tensor_dump_layers, + self.tp_size, + self.tp_rank, + self.pp_rank, ) - except RuntimeError: - raise ValueError( - f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node." 
- ) from None - + + if self.server_args.elastic_ep_backend == "mooncake": + # Mooncake does not support `monitored_barrier` + dist.barrier(group=get_tp_group().cpu_group) + else: + # Handle the case where some ranks do not finish loading. + try: + dist.monitored_barrier( + group=get_tp_group().cpu_group, + timeout=datetime.timedelta( + seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S + ), + wait_all_ranks=True, + ) + except RuntimeError: + raise ValueError( + f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node." + ) from None + def setup_collective_group(self, comm_plan, backend, rank_in_cluster): self.model_update_comm_plan = getattr(self, "model_update_comm_plan", {}) rank, comm_plan_args = get_dist_info_from_comm_plan(comm_plan, rank_in_cluster=rank_in_cluster, diff --git a/roll/third_party/sglang/v054_patch/scheduler.py b/roll/third_party/sglang/v054_patch/scheduler.py index bc8eccaba..a5cb49b1d 100644 --- a/roll/third_party/sglang/v054_patch/scheduler.py +++ b/roll/third_party/sglang/v054_patch/scheduler.py @@ -31,7 +31,7 @@ class SchedulerSA(Scheduler): def __init__(self, *args, **kwargs): import sys from roll.third_party.sglang.v054_patch.tp_worker import TpModelWorkerSA - sys.modules['sglang.srt.managers.scheduler'].__dict__['TpModelWorker'] = TpModelWorkerSA + sys.modules['sglang.srt.managers.tp_worker'].__dict__['TpModelWorker'] = TpModelWorkerSA super().__init__(*args, **kwargs) func_map_patch = [(SetupCollectiveGroupReqInput, self.setup_collective_group), (BroadcastBucketReqInput, self.broadcast_bucket), From 346a406ba4196b994e35a6a6b583cb057f3f7607 Mon Sep 17 00:00:00 2001 From: "zhaohaizhou.zhz" Date: Fri, 14 Nov 2025 10:05:13 +0800 Subject: [PATCH 37/58] (fix): vllm add missing argument is_lora in function update_parameter. 
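
This adds the `is_lora` argument to `update_parameter` on both the engine wrapper
(llm.py) and the v1 worker for every patched vLLM version (0.8.4, 0.10.0, 0.10.2,
0.11.0) and forwards it through the collective RPC, so LoRA adapter weights and full
model weights can be told apart during weight sync. A call site would pass the flag
explicitly, e.g. (illustrative sketch; variable names are not taken from this diff):

    llm.update_parameter(parameter_name, weight,
                         ranks_in_worker=[dist.get_rank()], is_lora=False)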
--- roll/third_party/vllm/vllm_0_10_0/llm.py | 4 ++-- roll/third_party/vllm/vllm_0_10_0/v1/worker.py | 4 ++-- roll/third_party/vllm/vllm_0_10_2/llm.py | 4 ++-- roll/third_party/vllm/vllm_0_10_2/v1/worker.py | 4 ++-- roll/third_party/vllm/vllm_0_11_0/llm.py | 4 ++-- roll/third_party/vllm/vllm_0_11_0/v1/worker.py | 4 ++-- roll/third_party/vllm/vllm_0_8_4/llm.py | 4 ++-- roll/third_party/vllm/vllm_0_8_4/v1/worker.py | 4 ++-- 8 files changed, 16 insertions(+), 16 deletions(-) diff --git a/roll/third_party/vllm/vllm_0_10_0/llm.py b/roll/third_party/vllm/vllm_0_10_0/llm.py index e23de923e..56aa1cfdf 100644 --- a/roll/third_party/vllm/vllm_0_10_0/llm.py +++ b/roll/third_party/vllm/vllm_0_10_0/llm.py @@ -211,13 +211,13 @@ def broadcast_bucket(self, src_pp_rank, meta_infos, bucket_size): def broadcast_parameter(self, *args, **kwargs): self.collective_rpc(method="broadcast_parameter", args=args, kwargs=kwargs) - def update_parameter(self, parameter_name, weight, ranks_in_worker): + def update_parameter(self, parameter_name, weight, ranks_in_worker, is_lora): if envs.VLLM_USE_V1: weight_dict = { "dtype": weight.dtype, "weight": weight.cpu().tolist() } - self.collective_rpc(method="update_parameter", args=(parameter_name, weight_dict, ranks_in_worker)) + self.collective_rpc(method="update_parameter", args=(parameter_name, weight_dict, ranks_in_worker, is_lora)) def update_parameter_in_bucket(self, meta_infos, buffer, ranks_in_worker): if envs.VLLM_USE_V1: diff --git a/roll/third_party/vllm/vllm_0_10_0/v1/worker.py b/roll/third_party/vllm/vllm_0_10_0/v1/worker.py index 9f8ac08f2..f65f07430 100644 --- a/roll/third_party/vllm/vllm_0_10_0/v1/worker.py +++ b/roll/third_party/vllm/vllm_0_10_0/v1/worker.py @@ -22,10 +22,10 @@ def __init__(self, *args, **kwargs): self.lora_params = OrderedDict() patch_vllm_lora_manager() - def update_parameter(self, parameter_name, weight, ranks_in_worker): + def update_parameter(self, parameter_name, weight, ranks_in_worker, is_lora): weight_dict = weight weight = torch.tensor(weight_dict["weight"], dtype=weight_dict["dtype"]).to(current_platform.device_type) - super().update_parameter(parameter_name, weight, ranks_in_worker) + super().update_parameter(parameter_name, weight, ranks_in_worker, is_lora) def broadcast_bucket(self, src_pp_rank, meta_infos, bucket_size): RecvBucketManager.dict_to_meta(meta_infos) diff --git a/roll/third_party/vllm/vllm_0_10_2/llm.py b/roll/third_party/vllm/vllm_0_10_2/llm.py index 7d37a9adc..fe38a85ad 100644 --- a/roll/third_party/vllm/vllm_0_10_2/llm.py +++ b/roll/third_party/vllm/vllm_0_10_2/llm.py @@ -263,13 +263,13 @@ def broadcast_bucket(self, src_pp_rank, meta_infos, bucket_size): def broadcast_parameter(self, *args, **kwargs): self.collective_rpc(method="broadcast_parameter", args=args, kwargs=kwargs) - def update_parameter(self, parameter_name, weight, ranks_in_worker): + def update_parameter(self, parameter_name, weight, ranks_in_worker, is_lora): if envs.VLLM_USE_V1: weight_dict = { "dtype": weight.dtype, "weight": weight.cpu().tolist() } - self.collective_rpc(method="update_parameter", args=(parameter_name, weight_dict, ranks_in_worker)) + self.collective_rpc(method="update_parameter", args=(parameter_name, weight_dict, ranks_in_worker, is_lora)) def update_parameter_in_bucket(self, meta_infos, buffer, ranks_in_worker): if envs.VLLM_USE_V1: diff --git a/roll/third_party/vllm/vllm_0_10_2/v1/worker.py b/roll/third_party/vllm/vllm_0_10_2/v1/worker.py index 3d976ad82..3b7a467cc 100644 --- a/roll/third_party/vllm/vllm_0_10_2/v1/worker.py 
+++ b/roll/third_party/vllm/vllm_0_10_2/v1/worker.py @@ -23,10 +23,10 @@ def __init__(self, *args, **kwargs): self.lora_params = OrderedDict() patch_vllm_lora_manager() - def update_parameter(self, parameter_name, weight, ranks_in_worker): + def update_parameter(self, parameter_name, weight, ranks_in_worker, is_lora): weight_dict = weight weight = torch.tensor(weight_dict["weight"], dtype=weight_dict["dtype"]).to(current_platform.device_type) - super().update_parameter(parameter_name, weight, ranks_in_worker) + super().update_parameter(parameter_name, weight, ranks_in_worker, is_lora) def broadcast_bucket(self, src_pp_rank, meta_infos, bucket_size): RecvBucketManager.dict_to_meta(meta_infos) diff --git a/roll/third_party/vllm/vllm_0_11_0/llm.py b/roll/third_party/vllm/vllm_0_11_0/llm.py index 8734db1a8..b5b3aa48d 100644 --- a/roll/third_party/vllm/vllm_0_11_0/llm.py +++ b/roll/third_party/vllm/vllm_0_11_0/llm.py @@ -284,13 +284,13 @@ def broadcast_bucket(self, src_pp_rank, meta_infos, bucket_size): def broadcast_parameter(self, *args, **kwargs): self.collective_rpc(method="broadcast_parameter", args=args, kwargs=kwargs) - def update_parameter(self, parameter_name, weight, ranks_in_worker): + def update_parameter(self, parameter_name, weight, ranks_in_worker, is_lora): if envs.VLLM_USE_V1: weight_dict = { "dtype": weight.dtype, "weight": weight.cpu().tolist() } - self.collective_rpc(method="update_parameter", args=(parameter_name, weight_dict, ranks_in_worker)) + self.collective_rpc(method="update_parameter", args=(parameter_name, weight_dict, ranks_in_worker, is_lora)) def update_parameter_in_bucket(self, meta_infos, buffer, ranks_in_worker): if envs.VLLM_USE_V1: diff --git a/roll/third_party/vllm/vllm_0_11_0/v1/worker.py b/roll/third_party/vllm/vllm_0_11_0/v1/worker.py index 8a4bdc79c..316ea41eb 100644 --- a/roll/third_party/vllm/vllm_0_11_0/v1/worker.py +++ b/roll/third_party/vllm/vllm_0_11_0/v1/worker.py @@ -22,10 +22,10 @@ def __init__(self, *args, **kwargs): self.lora_params = OrderedDict() patch_vllm_lora_manager() - def update_parameter(self, parameter_name, weight, ranks_in_worker): + def update_parameter(self, parameter_name, weight, ranks_in_worker, is_lora): weight_dict = weight weight = torch.tensor(weight_dict["weight"], dtype=weight_dict["dtype"]).cuda() - super().update_parameter(parameter_name, weight, ranks_in_worker) + super().update_parameter(parameter_name, weight, ranks_in_worker, is_lora) def broadcast_bucket(self, src_pp_rank, meta_infos, bucket_size): RecvBucketManager.dict_to_meta(meta_infos) diff --git a/roll/third_party/vllm/vllm_0_8_4/llm.py b/roll/third_party/vllm/vllm_0_8_4/llm.py index 3d6a835b1..0f716bb53 100644 --- a/roll/third_party/vllm/vllm_0_8_4/llm.py +++ b/roll/third_party/vllm/vllm_0_8_4/llm.py @@ -206,7 +206,7 @@ def broadcast_bucket(self, src_pp_rank, meta_infos, bucket_size): def broadcast_parameter(self, *args, **kwargs): self.collective_rpc(method="broadcast_parameter", args=args, kwargs=kwargs) - def update_parameter(self, parameter_name, weight, ranks_in_worker): + def update_parameter(self, parameter_name, weight, ranks_in_worker, is_lora): if envs.VLLM_USE_V1: weight_dict = { "dtype": weight.dtype, @@ -214,7 +214,7 @@ def update_parameter(self, parameter_name, weight, ranks_in_worker): } else: weight_dict = weight - self.collective_rpc(method="update_parameter", args=(parameter_name, weight_dict, ranks_in_worker)) + self.collective_rpc(method="update_parameter", args=(parameter_name, weight_dict, ranks_in_worker, is_lora)) def 
update_parameter_in_bucket(self, meta_infos, buffer, ranks_in_worker): if envs.VLLM_USE_V1: diff --git a/roll/third_party/vllm/vllm_0_8_4/v1/worker.py b/roll/third_party/vllm/vllm_0_8_4/v1/worker.py index ce7441cb9..a0e473d19 100644 --- a/roll/third_party/vllm/vllm_0_8_4/v1/worker.py +++ b/roll/third_party/vllm/vllm_0_8_4/v1/worker.py @@ -22,10 +22,10 @@ def __init__(self, *args, **kwargs): self.lora_params = OrderedDict() patch_vllm_lora_manager() - def update_parameter(self, parameter_name, weight, ranks_in_worker): + def update_parameter(self, parameter_name, weight, ranks_in_worker, is_lora): weight_dict = weight weight = torch.tensor(weight_dict["weight"], dtype=weight_dict["dtype"]).to(current_platform.device_type) - super().update_parameter(parameter_name, weight, ranks_in_worker) + super().update_parameter(parameter_name, weight, ranks_in_worker, is_lora) def broadcast_bucket(self, src_pp_rank, meta_infos, bucket_size): RecvBucketManager.dict_to_meta(meta_infos) From a98f4ce1edba86567773a34c288874650cd08d5e Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Fri, 5 Dec 2025 17:18:01 +0800 Subject: [PATCH 38/58] (feat): update mcore_adapter. --- .../src/mcore_adapter/adapters/lora_layer.py | 4 +- .../models/converter/dist_converter.py | 129 +++++++++++------- .../models/converter/model_converter.py | 30 +++- .../src/mcore_adapter/models/model_factory.py | 15 +- .../models/qwen3_next/__init__.py | 6 - 5 files changed, 112 insertions(+), 72 deletions(-) diff --git a/mcore_adapter/src/mcore_adapter/adapters/lora_layer.py b/mcore_adapter/src/mcore_adapter/adapters/lora_layer.py index ad7630a44..88980555b 100644 --- a/mcore_adapter/src/mcore_adapter/adapters/lora_layer.py +++ b/mcore_adapter/src/mcore_adapter/adapters/lora_layer.py @@ -31,6 +31,8 @@ from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge from peft.utils import transpose +from ..platforms import current_platform + class LoraParallelLinear(MegatronModule, LoraLayer): def __init__( @@ -271,7 +273,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N base_layer = self.get_base_layer() origin_device = base_layer.weight0.device if self.is_grouped else base_layer.weight.device if origin_device.type == "cpu": - self.to(device=torch.cuda.current_device()) + self.to(device=current_platform.current_device()) for active_adapter in adapter_names: if active_adapter in self.lora_A.keys(): diff --git a/mcore_adapter/src/mcore_adapter/models/converter/dist_converter.py b/mcore_adapter/src/mcore_adapter/models/converter/dist_converter.py index d79ba3361..f1a7e80b0 100644 --- a/mcore_adapter/src/mcore_adapter/models/converter/dist_converter.py +++ b/mcore_adapter/src/mcore_adapter/models/converter/dist_converter.py @@ -347,12 +347,12 @@ def _revert_column_parallel(self, weights: List["Tensor"]): return weights[0] return torch.cat(weights, dim=0) - def handle_column_parallel(self, name: str, weights: Union["Tensor", List["Tensor"]], vp_stage: int) -> Dict[str, "Tensor"]: + def handle_column_parallel(self, name: str, weights: Union["Tensor", List["Tensor"]]) -> Dict[str, "Tensor"]: if self.revert: weight = self._revert_column_parallel(weights) else: weight = self._convert_column_parallel(weights) - name = self.name_relocate(name, vp_stage=vp_stage) + name = self._name_relocate(name) return {name: weight} def _convert_row_parallel(self, weight: "Tensor"): @@ -366,12 +366,12 @@ def _revert_row_parallel(self, weights: List["Tensor"]): return weights[0] return torch.cat(weights, dim=1) - 
def handle_row_parallel(self, name: str, weights: Union["Tensor", List["Tensor"]], vp_stage: int) -> Dict[str, "Tensor"]: + def handle_row_parallel(self, name: str, weights: Union["Tensor", List["Tensor"]]) -> Dict[str, "Tensor"]: if self.revert: weight = self._revert_row_parallel(weights) else: weight = self._convert_row_parallel(weights) - name = self.name_relocate(name, vp_stage=vp_stage) + name = self._name_relocate(name) return {name: weight} def _convert_swiglu(self, weight: "Tensor"): @@ -390,12 +390,12 @@ def _revert_swiglu(self, weights: List["Tensor"]): weight_v = self._revert_column_parallel(weights_v) return StackedTensors([weight_w, weight_v], dim=0) - def handle_swiglu(self, name: str, weights: Union["Tensor", List["Tensor"]], vp_stage: int) -> Dict[str, "Tensor"]: + def handle_swiglu(self, name: str, weights: Union["Tensor", List["Tensor"]]) -> Dict[str, "Tensor"]: if self.revert: weight = self._revert_swiglu(weights) else: weight = self._convert_swiglu(weights) - name = self.name_relocate(name, vp_stage=vp_stage) + name = self._name_relocate(name) return {name: weight} def get_pure_name(self, name: str): @@ -412,7 +412,7 @@ def get_pure_name(self, name: str): pure_name = self.config.local_to_te_key_map[pure_name] return pure_name - def name_relocate(self, name: str, vp_stage: int, moe_index: Optional[int] = None): + def _name_relocate(self, name: str, moe_index: Optional[int] = None): pure_name = self.get_pure_name(name) if self.mca_config.transformer_impl == "local": if self.revert: # when revert to hf, convert to te name @@ -423,10 +423,7 @@ def name_relocate(self, name: str, vp_stage: int, moe_index: Optional[int] = Non moe_index = get_mca_moe_index(name) if moe_index is None else moe_index if layer_index is None: return pure_name - if self.revert: - layer_index = self.get_global_layer_index(layer_index, vp_stage=vp_stage) - else: - layer_index = self.get_local_layer_index(layer_index) + if moe_index is not None: if self.revert: if self.mca_config.moe_grouped_gemm: @@ -479,7 +476,7 @@ def get_global_layer_index(self, local_layer_index: int, vp_stage: int): global_layer_index -= 1 return global_layer_index - def handle_duplicated(self, name: str, weights: Union["Tensor", List["Tensor"]], vp_stage: int) -> Dict[str, "Tensor"]: + def handle_duplicated(self, name: str, weights: Union["Tensor", List["Tensor"]]) -> Dict[str, "Tensor"]: if self.revert: weight = weights[0] if not self.efficient_mode: @@ -494,7 +491,7 @@ def handle_duplicated(self, name: str, weights: Union["Tensor", List["Tensor"]], break else: weight = weights - name = self.name_relocate(name, vp_stage=vp_stage) + name = self._name_relocate(name) return {name: weight} def handle_grouped_duplicated(self, name: str, weights: Union["Tensor", List["Tensor"]]) -> Dict[str, "Tensor"]: @@ -512,33 +509,33 @@ def handle_grouped_duplicated(self, name: str, weights: Union["Tensor", List["Te else: raise NotImplementedError() moe_index = int(extract_suffix_number(name)) - return {self.name_relocate(name, moe_index=moe_index): weight} + return {self._name_relocate(name, moe_index=moe_index): weight} - def _convert_te_grouped_column(self, name: str, weights: "Tensor", vp_stage: int): + def _convert_te_grouped_column(self, name: str, weights: "Tensor"): if self.swiglu: weights = self._convert_swiglu(weights) else: weights = self._convert_column_parallel(weights) # weights = weights.transpose(0, 1) moe_index = get_mca_moe_index(name) % self.num_layers_for_expert - relocated_name = self.name_relocate(name, vp_stage=vp_stage) + 
str(moe_index) + relocated_name = self._name_relocate(name) + str(moe_index) return {relocated_name: weights} - def _revert_te_grouped_column(self, name: str, weights: List["Tensor"], vp_stage: int): + def _revert_te_grouped_column(self, name: str, weights: List["Tensor"]): if self.swiglu: weight = self._revert_swiglu(weights) else: weight = self._revert_column_parallel(weights) moe_index = int(extract_suffix_number(name)) - return {self.name_relocate(name, moe_index=moe_index, vp_stage=vp_stage): weight} + return {self._name_relocate(name, moe_index=moe_index): weight} - def _convert_grouped_column(self, name: str, weights: "Tensor", vp_stage: int): + def _convert_grouped_column(self, name: str, weights: "Tensor"): if self.swiglu: weights = self._convert_swiglu(weights) else: weights = self._convert_column_parallel(weights) weights = weights.transpose(0, 1) - relocated_name = self.name_relocate(name, vp_stage=vp_stage) + relocated_name = self._name_relocate(name) moe_index = get_mca_moe_index(name) % self.num_layers_for_expert if relocated_name not in self.weights_waiting_for_convert: self.weights_waiting_for_convert[relocated_name] = {} @@ -568,35 +565,35 @@ def _revert_column(weights: List["Tensor"]): ungrouped_weights = [_revert_column(weights) for weights in ungrouped_weights] return { - self.name_relocate(name, moe_index=moe_index, vp_stage=vp_stage): weight + self._name_relocate(name, moe_index=moe_index): weight for moe_index, weight in enumerate(ungrouped_weights) } - def handle_grouped_column(self, name: str, weights: Union["Tensor", List["Tensor"]], vp_stage: int) -> Dict[str, "Tensor"]: + def handle_grouped_column(self, name: str, weights: Union["Tensor", List["Tensor"]]) -> Dict[str, "Tensor"]: if self.revert: if self.use_te_grouped_moe: - return self._revert_te_grouped_column(name, weights, vp_stage=vp_stage) - return self._revert_grouped_column(name, weights, vp_stage=vp_stage) + return self._revert_te_grouped_column(name, weights) + return self._revert_grouped_column(name, weights) else: if self.use_te_grouped_moe: - return self._convert_te_grouped_column(name, weights, vp_stage=vp_stage) - return self._convert_grouped_column(name, weights, vp_stage=vp_stage) + return self._convert_te_grouped_column(name, weights) + return self._convert_grouped_column(name, weights) - def _convert_te_grouped_row(self, name: str, weights: "Tensor", vp_stage: int): + def _convert_te_grouped_row(self, name: str, weights: "Tensor"): weights = self._convert_row_parallel(weights) moe_index = get_mca_moe_index(name) % self.num_layers_for_expert - relocated_name = self.name_relocate(name, vp_stage=vp_stage) + str(moe_index) + relocated_name = self._name_relocate(name) + str(moe_index) return {relocated_name: weights} - def _revert_te_grouped_row(self, name: str, weights: List["Tensor"], vp_stage: int): + def _revert_te_grouped_row(self, name: str, weights: List["Tensor"]): weights = self._revert_row_parallel(weights) moe_index = int(extract_suffix_number(name)) - return {self.name_relocate(name, moe_index=moe_index, vp_stage=vp_stage): weights} + return {self._name_relocate(name, moe_index=moe_index): weights} - def _convert_grouped_row(self, name: str, weights: "Tensor", vp_stage: int): + def _convert_grouped_row(self, name: str, weights: "Tensor"): weights = self._convert_row_parallel(weights) weights = weights.transpose(0, 1) - relocated_name = self.name_relocate(name, vp_stage=vp_stage) + relocated_name = self._name_relocate(name) moe_index = get_mca_moe_index(name) % self.num_layers_for_expert 
if relocated_name not in self.weights_waiting_for_convert: self.weights_waiting_for_convert[relocated_name] = {} @@ -607,7 +604,7 @@ def _convert_grouped_row(self, name: str, weights: "Tensor", vp_stage: int): weights = [weight[1] for weight in weights] return {relocated_name: torch.stack(weights, dim=0).view(-1, self.mca_config.hidden_size)} - def _revert_grouped_row(self, name, weights: List["Tensor"], vp_stage: int): + def _revert_grouped_row(self, name, weights: List["Tensor"]): def _revert_grouped(weight: "Tensor"): weight = weight.view(self.num_layers_for_expert, -1, self.mca_config.hidden_size) expert_weights = torch.unbind(weight, dim=0) @@ -619,19 +616,19 @@ def _revert_grouped(weight: "Tensor"): ungrouped_weights = [[weights[i] for weights in ungrouped_weights] for i in range(self.num_layers_for_expert)] ungrouped_weights = [self._revert_row_parallel(weights) for weights in ungrouped_weights] return { - self.name_relocate(name, moe_index=moe_index, vp_stage=vp_stage): weight + self._name_relocate(name, moe_index=moe_index): weight for moe_index, weight in enumerate(ungrouped_weights) } - def handle_grouped_row(self, name: str, weights: Union["Tensor", List["Tensor"]], vp_stage: int) -> Dict[str, "Tensor"]: + def handle_grouped_row(self, name: str, weights: Union["Tensor", List["Tensor"]]) -> Dict[str, "Tensor"]: if self.revert: if self.use_te_grouped_moe: - return self._revert_te_grouped_row(name, weights, vp_stage=vp_stage) - return self._revert_grouped_row(name, weights, vp_stage=vp_stage) + return self._revert_te_grouped_row(name, weights) + return self._revert_grouped_row(name, weights) else: if self.use_te_grouped_moe: - return self._convert_te_grouped_row(name, weights, vp_stage=vp_stage) - return self._convert_grouped_row(name, weights, vp_stage=vp_stage) + return self._convert_te_grouped_row(name, weights) + return self._convert_grouped_row(name, weights) def name_match(self, pure_name: str, patterns: list[str] | dict[str, Any]): if pure_name in patterns: @@ -672,7 +669,43 @@ def get_global_moe_index(self, name: str) -> Optional[Union[int, List[int]]]: else: return [local_to_global(i) for i in local_moe_index] - def dist_convert(self, name: str, weights: Union["Tensor", List["Tensor"]], vp_stage: Optional[int] = None) -> Dict[str, "Tensor"]: + def preprocess_layer_index(self, name: str, vp_stage: int) -> str: + """ + Preprocess layer index for pipeline parallelism. + Converts between global and local layer indices before calling name_relocate. + """ + layer_index = get_mca_layer_index(name) + if layer_index is None: + return name + moe_index = get_mca_moe_index(name) + + if self.revert: + layer_index = self.get_global_layer_index(layer_index, vp_stage=vp_stage) + else: + layer_index = self.get_local_layer_index(layer_index) + + if name.startswith("mtp.layers."): + return add_mca_mtp_layer_prefix(remove_mca_weight_prefix(name), layer_index, moe_index) + return add_mca_layer_prefix(remove_mca_weight_prefix(name), layer_index, moe_index) + + def dist_convert( + self, + name: str, + weights: Union["Tensor", List["Tensor"]], + vp_stage: Optional[int] = None, + layer_index_preprocessed: bool = False, + ) -> Dict[str, "Tensor"]: + """ + Convert weights for distributed parallelism. + + Args: + name: Weight name + weights: Weight tensor(s) + vp_stage: Virtual pipeline stage + layer_index_preprocessed: If True, the name's layer index has already been preprocessed + for pipeline parallelism by the caller. 
If False (default), DistConverter will + handle the layer index conversion between global and local indices. + """ if vp_stage is None: vp_stage = self.virtual_pipeline_model_parallel_rank if ( @@ -687,23 +720,27 @@ def dist_convert(self, name: str, weights: Union["Tensor", List["Tensor"]], vp_s if not self.is_on_this_rank(name, vp_stage=vp_stage): return None + + if not layer_index_preprocessed: + name = self.preprocess_layer_index(name, vp_stage) + pure_name = self.get_pure_name(name) if pure_name.endswith(".bias"): pure_name = pure_name.replace(".bias", ".weight") if self.mca_config.moe_grouped_gemm and self.name_match(pure_name, self.config.grouped_duplicated_weights): return self.handle_grouped_duplicated(name, weights) if self.mca_config.moe_grouped_gemm and self.name_match(pure_name, self.config.grouped_column_weights): - return self.handle_grouped_column(name, weights, vp_stage=vp_stage) + return self.handle_grouped_column(name, weights) if self.mca_config.moe_grouped_gemm and self.name_match(pure_name, self.config.grouped_row_weights): - return self.handle_grouped_row(name, weights, vp_stage=vp_stage) + return self.handle_grouped_row(name, weights) if self.swiglu and self.name_match(pure_name, self.config.swiglu_weights): - return self.handle_swiglu(name, weights, vp_stage=vp_stage) + return self.handle_swiglu(name, weights) if self.name_match(pure_name, self.config.duplicated_weights): - return self.handle_duplicated(name, weights, vp_stage=vp_stage) + return self.handle_duplicated(name, weights) if self.name_match(pure_name, self.config.column_parallel_weights): - return self.handle_column_parallel(name, weights, vp_stage=vp_stage) + return self.handle_column_parallel(name, weights) if self.name_match(pure_name, self.config.row_parallel_weights): - return self.handle_row_parallel(name, weights, vp_stage=vp_stage) + return self.handle_row_parallel(name, weights) raise ValueError(f"name: {name}, pure_name: {pure_name}, config {self.config} swiglu: {self.swiglu}") def is_tensor_parallel_dup_weight(self, name: str) -> bool: diff --git a/mcore_adapter/src/mcore_adapter/models/converter/model_converter.py b/mcore_adapter/src/mcore_adapter/models/converter/model_converter.py index 314926b1c..d9b302777 100644 --- a/mcore_adapter/src/mcore_adapter/models/converter/model_converter.py +++ b/mcore_adapter/src/mcore_adapter/models/converter/model_converter.py @@ -60,9 +60,12 @@ def __init__( self.verbose = verbose self.template = get_template(mca_config.hf_model_type) self.template.set_mca_config_for_ops(self.mca_config) - tensor_model_parallel_rank = tensor_model_parallel_rank or mpu.get_tensor_model_parallel_rank() - pipeline_model_parallel_rank = pipeline_model_parallel_rank or mpu.get_pipeline_model_parallel_rank() - expert_model_parallel_rank = expert_model_parallel_rank or mpu.get_expert_model_parallel_rank() + if tensor_model_parallel_rank is None: + tensor_model_parallel_rank = mpu.get_tensor_model_parallel_rank() + if pipeline_model_parallel_rank is None: + pipeline_model_parallel_rank = mpu.get_pipeline_model_parallel_rank() + if expert_model_parallel_rank is None: + expert_model_parallel_rank = mpu.get_expert_model_parallel_rank() self.dist_converter = DistConverter( self.mca_config, tensor_model_parallel_rank=tensor_model_parallel_rank, @@ -152,13 +155,30 @@ def _mca_named_params_with_vp_stage(self, models): for mca_name, weight in sorted(mca_state_dict.items()): yield vp_stage, mca_name, weight - def convert_to_hf(self, mca_state_dict: Dict[str, list["Tensor"]], vp_stage: 
Optional[int] = None) -> Dict[str, "Tensor"]: + def convert_to_hf( + self, + mca_state_dict: Dict[str, list["Tensor"]], + vp_stage: Optional[int] = None, + layer_index_preprocessed: bool = False, + ) -> Dict[str, "Tensor"]: + """ + Convert Mca state dict to HuggingFace format. + + Args: + mca_state_dict: Dictionary of mca weight names to tensor lists + vp_stage: Virtual pipeline stage + layer_index_preprocessed: If True, the weight names' layer indices have already been + preprocessed for pipeline parallelism by the caller. If False (default), + DistConverter will handle the layer index conversion between global and local indices. + """ if vp_stage is None: vp_stage = mpu.get_virtual_pipeline_model_parallel_rank() hf_state_dict = {} for mca_name, weights in mca_state_dict.items(): - merged_named_weights = self.dist_converter.dist_convert(mca_name, weights, vp_stage=vp_stage) + merged_named_weights = self.dist_converter.dist_convert( + mca_name, weights, vp_stage=vp_stage, layer_index_preprocessed=layer_index_preprocessed + ) if merged_named_weights is None: continue converted = {} diff --git a/mcore_adapter/src/mcore_adapter/models/model_factory.py b/mcore_adapter/src/mcore_adapter/models/model_factory.py index 77975b9f9..f6792a7da 100644 --- a/mcore_adapter/src/mcore_adapter/models/model_factory.py +++ b/mcore_adapter/src/mcore_adapter/models/model_factory.py @@ -213,20 +213,7 @@ def from_pretrained( def save_pretrained(self, save_directory: str, state_dict=None): os.makedirs(save_directory, exist_ok=True) - if state_dict is None: - new_state_dict = {} - state_dict_model = self.state_dict_for_save_checkpoint() - for n, p in self.named_parameters(): - if not p.requires_grad: - continue - if n in state_dict_model: - new_state_dict[n] = state_dict_model[n] - key = n.replace('.weight', '._extra_state') - if key.endswith('._extra_state0'): - key = key.replace('._extra_state0', '._extra_state') - if key in state_dict_model: - new_state_dict[key] = state_dict_model[key] - state_dict = {"model": new_state_dict} + state_dict = state_dict if state_dict is not None else {"model": self.state_dict_for_save_checkpoint()} save_config_and_state_dict(save_directory, self.config, state_dict) def get_batch_on_this_cp_rank(self, batch: Dict[str, "torch.Tensor"], dim3_keys: List[str] = ["attention_mask"]): diff --git a/mcore_adapter/src/mcore_adapter/models/qwen3_next/__init__.py b/mcore_adapter/src/mcore_adapter/models/qwen3_next/__init__.py index f06e4f0af..b9ef623f6 100644 --- a/mcore_adapter/src/mcore_adapter/models/qwen3_next/__init__.py +++ b/mcore_adapter/src/mcore_adapter/models/qwen3_next/__init__.py @@ -36,12 +36,6 @@ def _mca_to_hf(self, weights): @dataclass class NextQKVConverOp(QKVConverOp): """query weight used for calculating query_states and gate""" - - def __post_init__(self): - super().__post_init__() - assert len(self.hf_names) == 3, f"QKVConverOp only support three hf_names {self.hf_names}" - assert len(self.mca_names) == 1, f"QKVConverOp only support one mca_name {self.mca_names}" - def _hf_to_mca(self, weights): q_weight, k_weight, v_weight = weights nh = self.mca_config.num_attention_heads From d5f07c9e5dece9262f2a251d5b42b6b7bd308887 Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Fri, 5 Dec 2025 17:18:36 +0800 Subject: [PATCH 39/58] (fix): fix get_cached_module_file. 
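Note on the fix below: prefetching the model repo's custom *.py modules via get_cached_module_file is best-effort. A file that fails to cache at this point only matters if transformers actually needs it, and in that case from_pretrained will raise its own, clearer error later. A minimal sketch of the pattern, under the assumption that the caller has already collected the candidate file names (the hypothetical prefetch_remote_code helper below is illustrative only; the real change lives in prepare_automap_files):

```python
from transformers.dynamic_module_utils import get_cached_module_file

def prefetch_remote_code(model_path: str, python_files: list[str]) -> None:
    # Warm the HF dynamic-module cache; ignore failures for optional files.
    for file_name in python_files:
        try:
            get_cached_module_file(model_path, file_name)
        except Exception:
            # A genuinely required file still fails loudly in from_pretrained.
            pass
```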
--- roll/models/model_providers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/roll/models/model_providers.py b/roll/models/model_providers.py index 54ecfc6db..939badc92 100644 --- a/roll/models/model_providers.py +++ b/roll/models/model_providers.py @@ -42,7 +42,11 @@ def prepare_automap_files(model_path: str): python_files.append(file_name) with file_lock_context(model_path): for file_name in python_files: - get_cached_module_file(model_path, file_name) + try: + get_cached_module_file(model_path, file_name) + except Exception: + # if it's a needed file, will raise an exception when calling from_pretrained + pass def default_tokenizer_provider(model_args: "ModelArguments", model_name_or_path: str=None): From 526646877dc06ff694f9188fd15be0c401b1dd76 Mon Sep 17 00:00:00 2001 From: "fengjingxuan.fjx" Date: Mon, 17 Nov 2025 19:02:30 +0800 Subject: [PATCH 40/58] (fix): fix bugs with metrics recording in the DPO pipeline. --- roll/pipeline/dpo/actor_worker.py | 5 ----- roll/pipeline/dpo/dpo_pipeline.py | 24 ++++++++++++++---------- roll/utils/metrics/metrics_manager.py | 4 ++-- 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/roll/pipeline/dpo/actor_worker.py b/roll/pipeline/dpo/actor_worker.py index 3926543f0..b97641d4e 100644 --- a/roll/pipeline/dpo/actor_worker.py +++ b/roll/pipeline/dpo/actor_worker.py @@ -98,11 +98,6 @@ def train_step(self, data: DataProto): append_to_dict(metrics, pg_metrics) metrics["actor/lr"] = self.strategy.scheduler.get_last_lr()[0] - metrics["actor/loss"] = np.mean(metrics["actor/loss"]) - metrics["actor/acc"] = np.mean(metrics["actor/acc"]) - metrics["actor/chosen_reward"] = np.mean(metrics["actor/chosen_reward"]) - metrics["actor/reject_reward"] = np.mean(metrics["actor/reject_reward"]) - metrics["actor/grad_norm"] = np.mean(metrics.pop("actor_train/grad_norm")) data.to("cpu") output = DataProto(meta_info={"metrics": metrics}) diff --git a/roll/pipeline/dpo/dpo_pipeline.py b/roll/pipeline/dpo/dpo_pipeline.py index b25a1d604..b5450c7f0 100644 --- a/roll/pipeline/dpo/dpo_pipeline.py +++ b/roll/pipeline/dpo/dpo_pipeline.py @@ -20,6 +20,7 @@ from roll.pipeline.dpo.actor_worker import get_logps, loss_fn from roll.pipeline.dpo.dpo_config import DPOConfig from roll.utils.logging import get_logger +from roll.utils.metrics.metrics_manager import MetricsManager logger = get_logger() @@ -175,7 +176,7 @@ def __init__(self, pipeline_config: DPOConfig): @torch.no_grad() def run(self): global_step = 0 - metrics = {} + metrics_mgr = MetricsManager() for epoch in range(int(self.pipeline_config.actor_train.training_args.num_train_epochs)): logger.info(f"epoch {epoch} start...") @@ -185,13 +186,13 @@ def run(self): continue logger.info(f"pipeline step {global_step} start...") - metrics.clear() + metrics_mgr.clear_metrics() if self.val_dataset and global_step % self.pipeline_config.eval_steps == 0: with Timer(name="val_step", logger=None) as val_step_timer: val_metrics = self.val() - metrics.update(val_metrics) - metrics["time/val_step"] = val_step_timer.last + metrics_mgr.add_reduced_metrics(val_metrics) + metrics_mgr.add_metric("time/val_step", val_step_timer.last) with Timer(name="step_total", logger=None) as step_total_timer: batch_dict: Dict @@ -200,18 +201,21 @@ def run(self): with Timer(name="cal_ref_log_probs", logger=None) as cal_ref_log_probs_timer: ref_log_probs = self.reference.compute_log_probs(batch, blocking=True) - metrics.update(ref_log_probs.meta_info.pop("metrics", {})) + 
metrics_mgr.add_reduced_metrics(ref_log_probs.meta_info.pop("metrics", {})) ref_log_probs.rename(old_keys="log_probs", new_keys="reference_log_probs") batch = batch.union(ref_log_probs) - metrics["time/cal_ref_log_probs"] = cal_ref_log_probs_timer.last + metrics_mgr.add_metric("time/cal_ref_log_probs", cal_ref_log_probs_timer.last) with Timer(name="actor_train", logger=None) as actor_train_timer: actor_train_refs = self.actor_train.train_step(batch, blocking=False) actor_train_refs: DataProto = DataProto.materialize_concat(data_refs=actor_train_refs) - metrics.update(actor_train_refs.meta_info.pop("metrics", {})) - metrics["time/actor_train"] = actor_train_timer.last - metrics["time/step_total"] = step_total_timer.last - + metrics_mgr.add_reduced_metrics(actor_train_refs.meta_info.pop("metrics", {})) + metrics_mgr.add_metric("time/actor_train", actor_train_timer.last) + metrics_mgr.add_metric("time/step_total", step_total_timer.last) + + metrics = metrics_mgr.get_metrics() + metrics = {k: float(v) for k, v in metrics.items()} + self.state.step = global_step self.state.log_history.append(metrics) self.tracker.log(values=metrics, step=global_step) diff --git a/roll/utils/metrics/metrics_manager.py b/roll/utils/metrics/metrics_manager.py index 4acb156fb..024f2423d 100644 --- a/roll/utils/metrics/metrics_manager.py +++ b/roll/utils/metrics/metrics_manager.py @@ -23,8 +23,8 @@ def add_metric(self, name: str, value: Any) -> None: def add_metrics(self, metrics_dict: Dict[str, Any]) -> None: self.metrics.update(metrics_dict) - def add_reduced_metrics(self, metrics_dict: Dict[str, Any], prefix: str = "") -> None: - reduced = reduce_metrics(metrics_dict) + def add_reduced_metrics(self, metrics_dict: Dict[str, Any], prefix: str = "", reduce_func=np.mean) -> None: + reduced = reduce_metrics(metrics_dict, reduce_func=reduce_func) if prefix: reduced = {f"{prefix}/{k}": v for k, v in reduced.items()} self.metrics.update(reduced) From 0e47311437951aa5d8d59c580789a7fd08ea0e6e Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Mon, 17 Nov 2025 19:23:25 +0800 Subject: [PATCH 41/58] (feat): add enable_old_logprobs, opt old log probs by cache. 
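In a strictly on-policy step with a single backward pass per DP rank, old_log_probs is numerically identical to the log-probs produced by the first training forward pass, so the separate actor_train.compute_log_probs round trip can be skipped. The pipelines tag every sample with a sample_uuid, fill old_log_probs with zeros as a placeholder, and the actor workers reuse the first-pass values (cached per uuid when ppo_epochs > 1). set_old_logprobs_status keeps the explicit computation (enable_old_logprobs = True) when more than one backward step per rank runs inside a training step or when init_kl_coef > 0, and force_disable_old_logprobs overrides everything. A hedged sketch of the reuse rule this patch implements (names follow the diff; DataProto plumbing and shapes are simplified for illustration):

```python
from typing import Dict, List, Optional

import torch

def old_log_probs_for_loss(
    log_probs: torch.Tensor,            # log-probs from the current forward pass
    sample_uuids: List[str],
    cache: Dict[str, torch.Tensor],     # per-worker cache, cleared at the end of train_step
    enable_old_logprobs: bool,
    precomputed: Optional[torch.Tensor],
    ppo_epochs: int,
) -> torch.Tensor:
    if enable_old_logprobs:
        # Classic path: old_log_probs were recomputed by actor_train.compute_log_probs.
        return precomputed
    if sample_uuids[0] in cache:
        # Later epochs over the same rollout: reuse the first-pass values.
        return torch.cat([cache[u] for u in sample_uuids], dim=0)
    old = log_probs.detach()  # first pass: on-policy, so log_probs == old_log_probs
    if ppo_epochs > 1:
        for i, u in enumerate(sample_uuids):
            cache[u] = old[i:i + 1]
    return old
```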
--- roll/configs/base_config.py | 43 +++++++++++- roll/pipeline/agentic/agentic_actor_worker.py | 7 +- roll/pipeline/agentic/agentic_config.py | 3 +- roll/pipeline/agentic/agentic_pipeline.py | 32 +++++---- roll/pipeline/base_worker.py | 44 ++++++++++++- roll/pipeline/rlvr/actor_pg_worker.py | 2 +- roll/pipeline/rlvr/actor_worker.py | 8 +-- roll/pipeline/rlvr/rlvr_config.py | 3 +- roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py | 21 ++++-- roll/pipeline/rlvr/rlvr_pipeline.py | 65 ++++++++++--------- roll/pipeline/rlvr/rlvr_vlm_pipeline.py | 30 +++++---- roll/utils/metrics/metrics_manager.py | 2 +- 12 files changed, 182 insertions(+), 78 deletions(-) diff --git a/roll/configs/base_config.py b/roll/configs/base_config.py index ca4d1f049..d30dce574 100644 --- a/roll/configs/base_config.py +++ b/roll/configs/base_config.py @@ -6,7 +6,7 @@ from typing import Dict, Literal, Optional, Union from roll.configs.worker_config import WorkerConfig, is_colocated -from roll.utils.config_utils import validate_megatron_batch_size +from roll.utils.config_utils import validate_megatron_batch_size, calculate_megatron_dp_size from roll.utils.logging import get_logger @@ -385,6 +385,8 @@ class PPOConfig(BaseConfig): enable_reference: bool = field( default=False, metadata={"help": "Whether to enable reference cluster for computing ref_log_probs."} ) + enable_old_logprobs: bool = field(default=False, metadata={"help": "Enable old_logprobs computation optimization for disable caching"}) + force_disable_old_logprobs: bool = field(default=False, metadata={"help": "Force disable old_logprobs computation optimization for disable caching, priority is higher than enable_old_logprobs"}) def __post_init__(self): super().__post_init__() @@ -410,7 +412,12 @@ def __post_init__(self): self.reference.name = "reference" self.critic.name = "critic" if self.use_kl_loss or self.init_kl_coef > 0: + logger.warning(f"use_kl_loss or init_kl_coef > 0, enable_reference = True") self.enable_reference = True + if self.force_disable_old_logprobs: + self.enable_old_logprobs = False + else: + self.set_old_logprobs_status() def set_max_steps(self, max_steps: int): actor_backward_batch_size = ( @@ -439,6 +446,40 @@ def set_max_steps(self, max_steps: int): logger.info(f"critic train max_steps without dp_size: {self.critic.training_args.max_steps}") self.max_steps = max_steps + def set_old_logprobs_status(self): + batch_size = self.rollout_batch_size * self.actor_infer.generating_args.num_return_sequences + actor_backward_batch_size = ( + self.actor_train.training_args.per_device_train_batch_size + * self.actor_train.training_args.gradient_accumulation_steps + ) + dp_size = 1 + if self.actor_train.strategy_args is not None: + if self.actor_train.strategy_args.strategy_name == "deepspeed_train": + dp_size = len(self.actor_train.device_mapping) + elif self.actor_train.strategy_args.strategy_name == "megatron_train": + strategy_config = self.actor_train.strategy_args.strategy_config + tp = strategy_config.get('tensor_model_parallel_size', 1) + pp = strategy_config.get('pipeline_model_parallel_size', 1) + cp = strategy_config.get('context_parallel_size', 1) + dp_size = calculate_megatron_dp_size(num_gpus=len(self.actor_train.device_mapping), + tensor_parallel_size=tp, + pipeline_parallel_size=pp, + context_parallel_size=cp) + + # Calculate backward steps per DP rank + backward_steps_per_rank = (batch_size // dp_size) // actor_backward_batch_size + + # Disable optimization only when multiple backward steps in single training step + # Multi-epoch 
training is actually a key scenario for optimization + if backward_steps_per_rank > 1: + # Multiple backward steps means model parameters change during training + # Cannot reuse cached logprobs across backward passes + self.enable_old_logprobs = True + + if self.init_kl_coef > 0: + logger.warning(f"init_kl_coef > 0, enable_old_logprobs = True") + self.enable_old_logprobs = True + @property def async_pipeline(self) -> bool: return self.async_generation_ratio > 0 diff --git a/roll/pipeline/agentic/agentic_actor_worker.py b/roll/pipeline/agentic/agentic_actor_worker.py index a4298a98a..a7f683e92 100644 --- a/roll/pipeline/agentic/agentic_actor_worker.py +++ b/roll/pipeline/agentic/agentic_actor_worker.py @@ -16,15 +16,14 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ response_mask = data.batch["response_mask"][:, 1:].long() ref_log_probs = data.batch["ref_log_probs"] - old_log_probs = data.batch["old_log_probs"] - infer_log_probs = data.batch.get("infer_logprobs", old_log_probs) - infer_log_probs = infer_log_probs if len(infer_log_probs) > 0 else old_log_probs - advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"] ) + old_log_probs = self.get_old_log_probs_with_cache(data, log_probs) + infer_log_probs = data.batch.get("infer_logprobs", old_log_probs) + infer_log_probs = infer_log_probs if len(infer_log_probs) > 0 else old_log_probs if self.pipeline_config.ratio_type == "segment": raise NotImplemented(f"ratio_type: {self.pipeline_config.ratio_type} not implemented") diff --git a/roll/pipeline/agentic/agentic_config.py b/roll/pipeline/agentic/agentic_config.py index 0c2d1c3b8..affbb81dd 100644 --- a/roll/pipeline/agentic/agentic_config.py +++ b/roll/pipeline/agentic/agentic_config.py @@ -180,6 +180,7 @@ class AgenticConfig(PPOConfig): ratio_type: Literal["token", "segment"] = field(default="token", metadata={"help": "Ratio type: token or segment"}) def __post_init__(self): + self.actor_infer.generating_args.num_return_sequences = 1 super().__post_init__() # default worker_cls @@ -195,8 +196,6 @@ def __post_init__(self): self.train_env_manager.name = "train_env" self.val_env_manager.name = "val_env" - self.actor_infer.generating_args.num_return_sequences = 1 - if self.render_save_dir: self.render_save_dir = os.path.join( self.render_save_dir, self.exp_name, datetime.now().strftime("%Y%m%d-%H%M%S") diff --git a/roll/pipeline/agentic/agentic_pipeline.py b/roll/pipeline/agentic/agentic_pipeline.py index 2ebd94789..925d56fcf 100644 --- a/roll/pipeline/agentic/agentic_pipeline.py +++ b/roll/pipeline/agentic/agentic_pipeline.py @@ -174,6 +174,8 @@ def run(self): with Timer(name="rollout", logger=None) as rollout_timer: batch.meta_info["is_offload_states"] = True batch = ray.get(self.train_rollout_scheduler.get_batch.remote(batch, self.pipeline_config.rollout_batch_size)) + sample_uuids = [f"{traj_id}_{i}" for i, traj_id in enumerate(batch.non_tensor_batch['traj_id'])] + batch.non_tensor_batch['sample_uuid'] = np.array(sample_uuids, dtype=object) if "get_batch_return_start_time" in batch.meta_info: metrics["time/get_batch_cost_train"] = time.time() - batch.meta_info.pop("get_batch_return_start_time") actor_infer_metrics = self.actor_infer.get_metrics() @@ -204,19 +206,29 @@ def run(self): metrics["time/step_ref_log_probs_values_reward"] = cal_timer.last with Timer(name="cal_old_log_probs_values", logger=None) as cal_old_logpb_timer: - # TODO: 
use engine log_probs as old_log_probs batch.meta_info["is_offload_states"] = False - old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs(batch, blocking=False) + if self.pipeline_config.enable_old_logprobs: + old_log_probs: DataProto = self.actor_train.compute_log_probs(batch, blocking=True) + batch.batch["old_log_probs"] = old_log_probs.batch["log_probs"] + avg_old_log_prob = masked_mean(batch.batch["old_log_probs"], batch.batch["response_mask"][:, 1:]) + metrics.update({"critic/old_log_prob/mean": avg_old_log_prob.item()}) + metrics.update(reduce_metrics(old_log_probs.meta_info.pop("metrics", {}))) + agg_entropy = agg_loss( + loss_mat=old_log_probs.batch["entropy"], + loss_mask=batch.batch["response_mask"][:, 1:], + loss_agg_mode="token-mean", + ) + metrics.update({"critic/entropy/mean": agg_entropy.item()}) + else: + batch.batch["old_log_probs"] = torch.zeros_like(batch.batch["attention_mask"][:, 1:]) + if self.pipeline_config.adv_estimator == "gae": values_refs: List[ray.ObjectRef] = self.critic.compute_values(batch, blocking=False) - old_log_probs = DataProto.materialize_concat(data_refs=old_log_probs_refs) + if self.pipeline_config.adv_estimator == "gae": values = DataProto.materialize_concat(data_refs=values_refs) batch = batch.union(values) metrics.update(reduce_metrics(values.meta_info.pop("metrics", {}))) - batch.batch["old_log_probs"] = old_log_probs.batch["log_probs"] - avg_old_log_prob = masked_mean(batch.batch["old_log_probs"], batch.batch["response_mask"][:, 1:]) - metrics.update({"critic/old_log_prob/mean": avg_old_log_prob.item()}) # Mock ref_log_probs using old_log_probs if reference cluster is disabled if not self.pipeline_config.enable_reference: @@ -224,14 +236,6 @@ def run(self): avg_ref_log_prob = masked_mean(batch.batch["ref_log_probs"], batch.batch["response_mask"][:, 1:]) metrics.update({"critic/ref_log_prob/mean": avg_ref_log_prob.item()}) - agg_entropy = agg_loss( - loss_mat=old_log_probs.batch["entropy"], - loss_mask=batch.batch["response_mask"][:, 1:], - loss_agg_mode="token-mean", - ) - metrics.update({"critic/entropy/mean": agg_entropy.item()}) - - metrics.update(reduce_metrics(old_log_probs.meta_info.pop("metrics", {}))) metrics["time/step_old_log_probs_values"] = cal_old_logpb_timer.last # TODO 当前这个还没用处 diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py index 8a4aac5bb..efbfad175 100644 --- a/roll/pipeline/base_worker.py +++ b/roll/pipeline/base_worker.py @@ -42,6 +42,7 @@ def __init__(self, worker_config: WorkerConfig): self.server_metrics = {} self.thread_server = None self.offload_manager = None + self._logprobs_cache = {} @register(dispatch_mode=Dispatch.ONE_TO_ALL) def initialize(self, pipeline_config): @@ -112,6 +113,7 @@ def train_step(self, data: DataProto): metrics["actor/lr"] = self.strategy.scheduler.get_last_lr()[0] data.to("cpu") + self._logprobs_cache.clear() output = DataProto(meta_info={"metrics": metrics}) return output @@ -257,6 +259,46 @@ def forward_func_log_probs(self, data: DataProto, output_tensor: torch.Tensor): entropy = self.strategy.op_compute_entropy(logits=output_tensor, attention_mask=data.batch["response_mask"]) return log_probs, {"log_probs": log_probs.clone().detach(), "entropy": entropy.clone().detach()} + def get_old_log_probs_with_cache(self, data: DataProto, log_probs: torch.Tensor) -> torch.Tensor: + """ + Get old_log_probs with intra-step caching when enable_old_logprobs == False. 
+ When caching is enabled, the first forward pass log_probs can be reused as old_log_probs + since they are mathematically equivalent in on-policy settings. + This method can be overridden by subclasses for custom caching behavior. + + Args: + data: DataProto containing input data and sample_uuids + log_probs: Current forward pass log_probs tensor + + Returns: + old_log_probs tensor (detached, no gradients) + """ + # Original computation path when caching is disabled + if self.pipeline_config.enable_old_logprobs or "sample_uuid" not in data.non_tensor_batch: + # When enable_old_logprobs=True, use the pre-computed old_log_probs from batch + return data.batch["old_log_probs"] + + sample_uuids = data.non_tensor_batch["sample_uuid"] + + # Check first sample_uuid for efficiency - if it exists, all likely exist + first_uuid = sample_uuids[0] + if first_uuid in self._logprobs_cache: + # All samples likely cached, retrieve all from cache + cached_old_log_probs = [] + + for sample_uuid in sample_uuids: + cached_old_log_probs.append(self._logprobs_cache[sample_uuid]) + + old_log_probs = torch.cat(cached_old_log_probs, dim=0) + else: + # Cache miss - use current log_probs as old_log_probs (mathematically equivalent in on-policy) + old_log_probs = log_probs.detach() + if self.pipeline_config.ppo_epochs > 1: + for i, sample_uuid in enumerate(sample_uuids): + self._logprobs_cache[sample_uuid] = old_log_probs[i:i+1] + + return old_log_probs + def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ loss func接口定义: @@ -266,12 +308,12 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): response_mask = data.batch["response_mask"][:, 1:].long() ref_log_probs = data.batch["ref_log_probs"] - old_log_probs = data.batch["old_log_probs"] advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"] ) + old_log_probs = self.get_old_log_probs_with_cache(data, log_probs) ratio = (log_probs - old_log_probs).exp() diff --git a/roll/pipeline/rlvr/actor_pg_worker.py b/roll/pipeline/rlvr/actor_pg_worker.py index 813e582f5..477438595 100644 --- a/roll/pipeline/rlvr/actor_pg_worker.py +++ b/roll/pipeline/rlvr/actor_pg_worker.py @@ -33,12 +33,12 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] - old_log_probs = data.batch["old_log_probs"] advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"] ) + old_log_probs = self.get_old_log_probs_with_cache(data, log_probs) valid_samples = torch.any(final_response_mask > 0, dim=1).float() sample_weights = self.compute_sample_weights(data, response_mask) diff --git a/roll/pipeline/rlvr/actor_worker.py b/roll/pipeline/rlvr/actor_worker.py index 47d8a6b51..19d0c66de 100644 --- a/roll/pipeline/rlvr/actor_worker.py +++ b/roll/pipeline/rlvr/actor_worker.py @@ -16,17 +16,15 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ response_mask = data.batch["response_mask"][:, 1:].long() final_response_mask = data.batch.get("final_response_mask", response_mask) - ref_log_probs = data.batch["ref_log_probs"] - old_log_probs = data.batch["old_log_probs"] - infer_log_probs = data.batch.get("infer_logprobs", old_log_probs) - infer_log_probs = infer_log_probs if 
len(infer_log_probs) > 0 else old_log_probs - advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"] ) + old_log_probs = self.get_old_log_probs_with_cache(data, log_probs) + infer_log_probs = data.batch.get("infer_logprobs", old_log_probs) + infer_log_probs = infer_log_probs if len(infer_log_probs) > 0 else old_log_probs loss_scale =None if self.worker_config.use_dynamic_batching_in_train and self.pipeline_config.loss_agg_mode == "seq-mean-token-sum": diff --git a/roll/pipeline/rlvr/rlvr_config.py b/roll/pipeline/rlvr/rlvr_config.py index cb837102e..eb0c19fc6 100644 --- a/roll/pipeline/rlvr/rlvr_config.py +++ b/roll/pipeline/rlvr/rlvr_config.py @@ -179,6 +179,7 @@ class RLVRConfig(PPOConfig): error_max_len_threshold: int = field(default=9999999999) def __post_init__(self): + self.actor_infer.generating_args.num_return_sequences = self.num_return_sequences_in_group super().__post_init__() # default worker_cls @@ -193,8 +194,6 @@ def __post_init__(self): logger.info(f"actor_train.worker_cls: {self.actor_train.worker_cls}") - self.actor_infer.generating_args.num_return_sequences = self.num_return_sequences_in_group - self.domain_2_tag = None self.tag_2_domain = None if self.rewards is not None: diff --git a/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py b/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py index 21e919907..f5078e19b 100644 --- a/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py @@ -1,5 +1,6 @@ import json import os +import uuid from typing import Any, Dict, List, Optional import ray @@ -358,6 +359,7 @@ def run(self): batch.non_tensor_batch[key] = np.repeat( value, self.actor_infer.worker_config.generating_args.num_return_sequences ) + batch.non_tensor_batch['sample_uuid'] = np.array([str(uuid.uuid4()) for _ in range(batch.batch.shape[0])], dtype=object) with Timer(name="cal_ref_log_probs_reward", logger=None) as cal_timer: if self.pipeline_config.enable_reference: @@ -387,18 +389,23 @@ def run(self): batch.meta_info["is_offload_states"] = False if self.pipeline_config.adv_estimator == "gae": values_refs: List[ray.ObjectRef] = self.critic.compute_values(batch, blocking=False) - old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs( - batch, blocking=False - ) - old_log_probs = DataProto.materialize_concat(data_refs=old_log_probs_refs) + + if self.pipeline_config.enable_old_logprobs: + old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs( + batch, blocking=False + ) + old_log_probs = DataProto.materialize_concat(data_refs=old_log_probs_refs) + batch.batch["old_log_probs"] = old_log_probs.batch["log_probs"] + metrics.update(reduce_metrics(old_log_probs.meta_info.pop("metrics", {}))) + else: + # Use zeros when optimization is enabled + batch.batch["old_log_probs"] = torch.zeros_like(batch.batch["attention_mask"][:, 1:]) + if self.pipeline_config.adv_estimator == "gae": values = DataProto.materialize_concat(data_refs=values_refs) batch = batch.union(values) metrics.update(reduce_metrics(values.meta_info.pop("metrics", {}))) - batch.batch["old_log_probs"] = old_log_probs.batch["log_probs"] - metrics.update(reduce_metrics(old_log_probs.meta_info.pop("metrics", {}))) - # Mock ref_log_probs using old_log_probs if reference is disabled if not self.pipeline_config.enable_reference: batch.batch["ref_log_probs"] = batch.batch["old_log_probs"].clone() diff --git 
a/roll/pipeline/rlvr/rlvr_pipeline.py b/roll/pipeline/rlvr/rlvr_pipeline.py index 5565e9989..bd5b469cd 100644 --- a/roll/pipeline/rlvr/rlvr_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_pipeline.py @@ -3,11 +3,13 @@ import math import os import time +import uuid from datetime import datetime from functools import partial from typing import Any, Dict, List, Optional import datasets +import numpy as np import ray import torch from codetiming import Timer @@ -536,7 +538,7 @@ def run(self): batch = generate_output batch.meta_info["global_step"] = global_step - + batch.non_tensor_batch['sample_uuid'] = np.array([str(uuid.uuid4()) for _ in range(batch.batch.shape[0])], dtype=object) with Timer(name="cal_ref_log_probs", logger=None) as cal_ref_log_probs_timer: @@ -567,42 +569,47 @@ def run(self): batch.meta_info["is_offload_states"] = False if self.pipeline_config.adv_estimator == "gae": values_refs: List[ray.ObjectRef] = self.critic.compute_values(batch, blocking=False) - if self.pipeline_config.actor_train.use_dynamic_batching_in_infer: - batch, dynamic_batching_metrics = dynamic_batching_shard( - batch, - self.actor_train.dp_size, - self.pipeline_config.actor_train.max_tokens_per_microbatch_in_infer, - self.pipeline_config.actor_train.sequence_length_round_in_infer, - "actor_train/compute_log_probs", - ) - metrics_mgr.add_metrics(dynamic_batching_metrics) - old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs(batch, blocking=False) - old_log_probs = DataProto.materialize_concat(data_refs=old_log_probs_refs) - # Customize_logging metrics, Double check call twice - if self.pipeline_config.save_logging_board_dir: - old_log_probs_refs2: List[ray.ObjectRef] = self.actor_train.compute_log_probs( - batch, blocking=False + if self.pipeline_config.enable_old_logprobs: + if self.pipeline_config.actor_train.use_dynamic_batching_in_infer: + batch, dynamic_batching_metrics = dynamic_batching_shard( + batch, + self.actor_train.dp_size, + self.pipeline_config.actor_train.max_tokens_per_microbatch_in_infer, + self.pipeline_config.actor_train.sequence_length_round_in_infer, + "actor_train/compute_log_probs", + ) + metrics_mgr.add_metrics(dynamic_batching_metrics) + old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs(batch, blocking=False) + old_log_probs = DataProto.materialize_concat(data_refs=old_log_probs_refs) + + # Customize_logging metrics, Double check call twice + if self.pipeline_config.save_logging_board_dir: + old_log_probs_refs2: List[ray.ObjectRef] = self.actor_train.compute_log_probs( + batch, blocking=False + ) + old_log_probs2 = DataProto.materialize_concat(data_refs=old_log_probs_refs2) + batch.batch["old_log_probs2"] = old_log_probs2.batch["log_probs"] + batch.batch["old_log_probs2_entropy"] = old_log_probs2.batch["entropy"] + + agg_entropy = agg_loss( + loss_mat=old_log_probs.batch["entropy"], + loss_mask=batch.batch["response_mask"][:, 1:], + loss_agg_mode="token-mean", ) - old_log_probs2 = DataProto.materialize_concat(data_refs=old_log_probs_refs2) - batch.batch["old_log_probs2"] = old_log_probs2.batch["log_probs"] - batch.batch["old_log_probs2_entropy"] = old_log_probs2.batch["entropy"] - - agg_entropy = agg_loss( - loss_mat=old_log_probs.batch["entropy"], - loss_mask=batch.batch["response_mask"][:, 1:], - loss_agg_mode="token-mean", - ) - batch.meta_info["agg_entropy"] = agg_entropy + batch.meta_info["agg_entropy"] = agg_entropy + + batch.batch["old_log_probs"] = old_log_probs.batch["log_probs"] + 
metrics_mgr.add_reduced_metrics(old_log_probs.meta_info.pop("metrics", {})) + else: + # Use zeros when optimization is enabled + batch.batch["old_log_probs"] = torch.zeros_like(batch.batch["attention_mask"][:, 1:]) if self.pipeline_config.adv_estimator == "gae": values = DataProto.materialize_concat(data_refs=values_refs) batch = batch.union(values) metrics_mgr.add_reduced_metrics(values.meta_info.pop("metrics", {})) - batch.batch["old_log_probs"] = old_log_probs.batch["log_probs"] - metrics_mgr.add_reduced_metrics(old_log_probs.meta_info.pop("metrics", {})) - # Mock ref_log_probs using old_log_probs if reference is disabled if not self.pipeline_config.enable_reference: batch.batch["ref_log_probs"] = batch.batch["old_log_probs"].clone() diff --git a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py index 54df18895..efd99228e 100644 --- a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py @@ -1,10 +1,12 @@ import copy import json import os +import uuid from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union import datasets +import numpy as np import PIL.Image as Image import ray import torch @@ -544,6 +546,7 @@ def run(self): # mark here to make megatron get_data_input broadcast with non_batch_tensor batch.meta_info["_broadcast_non_tensor_batch"]= True + batch.non_tensor_batch['sample_uuid'] = np.array([str(uuid.uuid4()) for _ in range(batch.batch.shape[0])], dtype=object) with Timer(name="cal_ref_log_probs", logger=None) as cal_ref_log_probs_timer: if self.pipeline_config.enable_reference: ref_log_probs = self.reference.compute_log_probs(batch, blocking=True) @@ -556,23 +559,28 @@ def run(self): batch.meta_info["is_offload_states"] = False if self.pipeline_config.adv_estimator == "gae": values_refs: List[ray.ObjectRef] = self.critic.compute_values(batch, blocking=False) - old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs(batch, blocking=False) - old_log_probs = DataProto.materialize_concat(data_refs=old_log_probs_refs) - agg_entropy = agg_loss( - loss_mat=old_log_probs.batch["entropy"], - loss_mask=batch.batch["response_mask"][:, 1:], - loss_agg_mode="token-mean", - ) - batch.meta_info["agg_entropy"] = agg_entropy + + if self.pipeline_config.enable_old_logprobs: + old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs(batch, blocking=False) + old_log_probs = DataProto.materialize_concat(data_refs=old_log_probs_refs) + agg_entropy = agg_loss( + loss_mat=old_log_probs.batch["entropy"], + loss_mask=batch.batch["response_mask"][:, 1:], + loss_agg_mode="token-mean", + ) + batch.meta_info["agg_entropy"] = agg_entropy + + batch.batch["old_log_probs"] = old_log_probs.batch["log_probs"] + metrics_mgr.add_reduced_metrics(old_log_probs.meta_info.pop("metrics", {})) + else: + # Use zeros when optimization is enabled + batch.batch["old_log_probs"] = torch.zeros_like(batch.batch["attention_mask"][:, 1:]) if self.pipeline_config.adv_estimator == "gae": values = DataProto.materialize_concat(data_refs=values_refs) batch = batch.union(values) metrics_mgr.add_reduced_metrics(values.meta_info.pop("metrics", {})) - batch.batch["old_log_probs"] = old_log_probs.batch["log_probs"] - metrics_mgr.add_reduced_metrics(old_log_probs.meta_info.pop("metrics", {})) - # Mock ref_log_probs using old_log_probs if reference is disabled if not self.pipeline_config.enable_reference: batch.batch["ref_log_probs"] = batch.batch["old_log_probs"].clone() diff --git 
a/roll/utils/metrics/metrics_manager.py b/roll/utils/metrics/metrics_manager.py index 024f2423d..3d2ba15ce 100644 --- a/roll/utils/metrics/metrics_manager.py +++ b/roll/utils/metrics/metrics_manager.py @@ -183,7 +183,7 @@ def add_values_metrics(self, batch, prefix: str = "critic") -> Dict[str, Any]: response_mask = batch.batch["final_response_mask"].clone().bool() raw_advantages = batch.batch["raw_advantages"] returns = batch.batch["returns"] - agg_entropy = batch.meta_info.get("agg_entropy", 0.0) + agg_entropy = batch.meta_info.get("agg_entropy", torch.tensor(0)) max_score = 1 min_score = 0 From 3de37e1b5a4fbc232700d3b335c59086c3b5eeb4 Mon Sep 17 00:00:00 2001 From: "hongzhen.yj" Date: Tue, 18 Nov 2025 11:29:50 +0800 Subject: [PATCH 42/58] (fix): update image loading logic for byte data in rlvr_vlm_pipeline.py --- roll/pipeline/rlvr/rlvr_vlm_pipeline.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py index efd99228e..052175c46 100644 --- a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py @@ -3,6 +3,7 @@ import os import uuid from functools import partial +from io import BytesIO from typing import Any, Dict, List, Optional, Tuple, Union import datasets @@ -120,10 +121,18 @@ def encode_function( if image is None: image_flag[idx] = False try: - image_out = load_images(image if isinstance(image, (list, tuple)) else [image], timeout=None) + if isinstance(image, bytes): # bytes data + # TODO: support multiple images + image_out = Image.open(BytesIO(image)) + else: + image_out = load_images(image if isinstance(image, (list, tuple)) else [image], timeout=None) except Exception as e: - image_out = [Image.new("RGB", (224, 224), (255, 255, 255))] * len(image) - logger.error(f"Failed to get image: {image}") + if isinstance(image, bytes): + image_out = [Image.new("RGB", (224, 224), (255, 255, 255))] + logger.error(f"Failed to get image with type: {type(image)}") + else: + image_out = [Image.new("RGB", (224, 224), (255, 255, 255))] * len(image) + logger.error(f"Failed to get image: {image}") # since infer-image use pil image as input while train-engine use # processed data, process image here to make them use same image # refer to the following for Spatial Understanding with Qwen2.5-VL From 5caa55c5420b591d76629cdc37387e10b59535ed Mon Sep 17 00:00:00 2001 From: lzc410374 Date: Tue, 18 Nov 2025 11:53:35 +0800 Subject: [PATCH 43/58] (feat): mcore_adapter support qwen3vl. 
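This patch wires Qwen3-VL into mcore_adapter: a dist-parallel config and weight-name template, a Qwen3VLConfig, and a Qwen3VLModel whose decoder threads visual_pos_masks and deepstack_visual_embeds through a custom transformer block. One detail worth calling out is the template.py change: Qwen3-VL keeps its language-model settings under hf_config.text_config, so that nested dict is flattened onto the HF config before the usual config_hf_to_mca mapping runs. A rough sketch of that conversion step, for illustration only (the real code path is Template.convert_hf_to_mca_config_kws; the function name below is hypothetical):

```python
def hf_config_to_mca_kwargs(hf_config, config_hf_to_mca: dict) -> dict:
    # Flatten the nested text config first, as the new template.py hunk does.
    if hasattr(hf_config, "text_config"):
        for k, v in hf_config.text_config.to_dict().items():
            setattr(hf_config, k, v)
    kwargs = {}
    for hf_key, mca_key in config_hf_to_mca.items():
        if hasattr(hf_config, hf_key):
            # e.g. "rms_norm_eps" -> "layernorm_epsilon", "rope_theta" -> "rotary_base"
            kwargs[mca_key] = getattr(hf_config, hf_key)
    return kwargs
```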
--- .../src/mcore_adapter/models/__init__.py | 1 + .../models/converter/template.py | 6 + .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- .../mcore_adapter/models/qwen3_vl/__init__.py | 86 ++++ .../models/qwen3_vl/config_qwen3_vl.py | 49 +++ .../models/qwen3_vl/modeling_qwen3_vl.py | 394 ++++++++++++++++++ .../models/qwen3_vl/rope_utils.py | 257 ++++++++++++ .../models/qwen3_vl/transformer_block.py | 360 ++++++++++++++++ 8 files changed, 1154 insertions(+), 1 deletion(-) create mode 100644 mcore_adapter/src/mcore_adapter/models/qwen3_vl/__init__.py create mode 100644 mcore_adapter/src/mcore_adapter/models/qwen3_vl/config_qwen3_vl.py create mode 100644 mcore_adapter/src/mcore_adapter/models/qwen3_vl/modeling_qwen3_vl.py create mode 100644 mcore_adapter/src/mcore_adapter/models/qwen3_vl/rope_utils.py create mode 100644 mcore_adapter/src/mcore_adapter/models/qwen3_vl/transformer_block.py diff --git a/mcore_adapter/src/mcore_adapter/models/__init__.py b/mcore_adapter/src/mcore_adapter/models/__init__.py index f8fea2dba..1839c09db 100644 --- a/mcore_adapter/src/mcore_adapter/models/__init__.py +++ b/mcore_adapter/src/mcore_adapter/models/__init__.py @@ -10,6 +10,7 @@ qwen3, qwen3_moe, qwen3_next, + qwen3_vl, ) from .auto import AutoConfig, AutoModel from .model_config import McaModelConfig diff --git a/mcore_adapter/src/mcore_adapter/models/converter/template.py b/mcore_adapter/src/mcore_adapter/models/converter/template.py index 901ba652c..9b509e820 100644 --- a/mcore_adapter/src/mcore_adapter/models/converter/template.py +++ b/mcore_adapter/src/mcore_adapter/models/converter/template.py @@ -310,6 +310,12 @@ def convert_hf_to_mca_config(self, hf_config, **kw_args): return AutoMcaModelConfig.for_model(self.hf_model_type, **kw_args) def convert_hf_to_mca_config_kws(self, hf_config: "PretrainedConfig", **kw_args): + # TODO: support text_config + if hasattr(hf_config, "text_config"): + text_config = hf_config.text_config.to_dict() + for k, v in text_config.items(): + setattr(hf_config, k, v) + for k, v in self.config_hf_to_mca.items(): if hasattr(hf_config, k): kw_args[v] = getattr(hf_config, k) diff --git a/mcore_adapter/src/mcore_adapter/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/mcore_adapter/src/mcore_adapter/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 1addc8bf9..14481889a 100644 --- a/mcore_adapter/src/mcore_adapter/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/mcore_adapter/src/mcore_adapter/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -23,7 +23,7 @@ def __init__(self, config: "Qwen2_5_VLConfig", **kwargs): if self.pre_process: self.vision_model = Qwen2_5_VisionTransformerPretrainedModel._from_config( Qwen2_5_VLVisionConfig(**config.vision_config), - attn_implementation="flash_attention_2", + attn_implementation="sdpa", torch_dtype=self.config.params_dtype, ).to(current_platform.current_device()) for param in self.vision_model.parameters(): diff --git a/mcore_adapter/src/mcore_adapter/models/qwen3_vl/__init__.py b/mcore_adapter/src/mcore_adapter/models/qwen3_vl/__init__.py new file mode 100644 index 000000000..054c12697 --- /dev/null +++ b/mcore_adapter/src/mcore_adapter/models/qwen3_vl/__init__.py @@ -0,0 +1,86 @@ +from ..converter.dist_converter import DistParallelConfig, default_dist_config, register_dist_config +from ..converter.template import ( + QKVBiasConverOp, + QKVConverOp, + RenameConverOp, + StackConverOp, + register_template, +) +from .config_qwen3_vl import Qwen3VLConfig +from .modeling_qwen3_vl import Qwen3VLModel + + +register_dist_config( + "qwen3_vl", + 
default_dist_config.merge_configs( + DistParallelConfig( + pre_process_weights=["vision_model.*"], + duplicated_weights=["vision_model.*"], + ) + ), +) + +register_template( + "qwen3_vl", + hf_layer_prefix="model.language_model.layers.", + config_hf_to_mca={ + "max_position_embeddings": "max_sequence_length", + "hidden_size": "hidden_size", + "attention_bias": "add_qkv_bias", + "head_dim": "kv_channels", + "num_attention_heads": "num_attention_heads", + "num_key_value_heads": "num_query_groups", + "num_hidden_layers": "num_layers", + "rms_norm_eps": "layernorm_epsilon", + "vocab_size": "padded_vocab_size", + "attention_dropout": "attention_dropout", + "rope_theta": "rotary_base", + "intermediate_size": "ffn_hidden_size", + "tie_word_embeddings": "tie_embeddings_and_output_weights", + + # vit related + "vision_start_token_id": "vision_start_token_id", + "vision_end_token_id": "vision_end_token_id", + "vision_token_id": "vision_token_id", + "image_token_id": "image_token_id", + "video_token_id": "video_token_id", + "vision_config": "vision_config", + "rope_scaling": "rope_scaling", + }, + constant_mca_config={ + "swiglu": True, + "position_embedding_type": "mrope", + "normalization": "RMSNorm", + "add_bias_linear": False, + "hidden_dropout": 0.0, + "rotary_percent": 1.0, + "qk_layernorm": True, + }, + weight_converters=[ + RenameConverOp(hf_names="lm_head.weight", mca_names="output_layer.weight"), + RenameConverOp(hf_names="model.language_model.embed_tokens.weight", mca_names="embedding.word_embeddings.weight"), + RenameConverOp(hf_names=".input_layernorm.weight", mca_names=".self_attention.linear_qkv.layer_norm_weight"), + RenameConverOp(hf_names=".self_attn.o_proj.weight", mca_names=".self_attention.linear_proj.weight"), + RenameConverOp(hf_names=".self_attn.q_norm.weight", mca_names=".self_attention.q_layernorm.weight"), + RenameConverOp(hf_names=".self_attn.k_norm.weight", mca_names=".self_attention.k_layernorm.weight"), + RenameConverOp(hf_names=".post_attention_layernorm.weight", mca_names=".mlp.linear_fc1.layer_norm_weight"), + RenameConverOp(hf_names="model.language_model.norm.weight", mca_names="decoder.final_layernorm.weight"), + StackConverOp( + hf_names=[".mlp.gate_proj.weight", ".mlp.up_proj.weight"], mca_names=".mlp.linear_fc1.weight", dim=0 + ), + RenameConverOp(hf_names=".mlp.down_proj.weight", mca_names=".mlp.linear_fc2.weight"), + QKVConverOp( + hf_names=[".self_attn.q_proj.weight", ".self_attn.k_proj.weight", ".self_attn.v_proj.weight"], + mca_names=".self_attention.linear_qkv.weight", + ), + QKVBiasConverOp( + hf_names=[".self_attn.q_proj.bias", ".self_attn.k_proj.bias", ".self_attn.v_proj.bias"], + mca_names=".self_attention.linear_qkv.bias", + ), + RenameConverOp(hf_names="model.visual.{}", mca_names="vision_model.{}"), + + ], +) + + +__all__ = ["Qwen3VLConfig", "Qwen3VLModel"] diff --git a/mcore_adapter/src/mcore_adapter/models/qwen3_vl/config_qwen3_vl.py b/mcore_adapter/src/mcore_adapter/models/qwen3_vl/config_qwen3_vl.py new file mode 100644 index 000000000..2a2774cd3 --- /dev/null +++ b/mcore_adapter/src/mcore_adapter/models/qwen3_vl/config_qwen3_vl.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass, field +from typing import Optional + +from transformers import PretrainedConfig + +from ...utils import get_logger +from ..auto.config_auto import register_config +from ..model_config import McaModelConfig + + +logger = get_logger(__name__) + +@register_config("qwen3_vl") +@dataclass +class Qwen3VLConfig(McaModelConfig): + vision_start_token_id: int = 151652 + 
vision_end_token_id: int = 151653 + vision_token_id: int = 151654 + image_token_id: int = 151655 + video_token_id: int = 151656 + vision_config: Optional[dict] = field( + default=None, + metadata={"help": "Vision model config."}, + ) + text_config: Optional[dict] = field( + default=None, + metadata={"help": "Text model config."}, + ) + rope_scaling: Optional[dict] = field( + default=None, + metadata={"help": "Rope scaling."}, + ) + + def __post_init__(self): + logger.info(f"{self.text_config}") + super().__post_init__() + from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLVisionConfig + + if isinstance(self.vision_config, PretrainedConfig): + self.vision_config = self.vision_config.to_dict() + vision_config_obj = Qwen3VLVisionConfig(**self.vision_config) + self.merge_size = vision_config_obj.spatial_merge_size + self.pixel_values_dim = ( + vision_config_obj.patch_size + * vision_config_obj.patch_size + * vision_config_obj.in_channels + * vision_config_obj.temporal_patch_size + ) # 1176 + self.mrope_section = self.rope_scaling.get("mrope_section") diff --git a/mcore_adapter/src/mcore_adapter/models/qwen3_vl/modeling_qwen3_vl.py b/mcore_adapter/src/mcore_adapter/models/qwen3_vl/modeling_qwen3_vl.py new file mode 100644 index 000000000..600fe1ab7 --- /dev/null +++ b/mcore_adapter/src/mcore_adapter/models/qwen3_vl/modeling_qwen3_vl.py @@ -0,0 +1,394 @@ +from typing import List, Optional + +import torch +from megatron.core import mpu +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.utils import deprecate_inference_params +from torch import Tensor + +from ..auto.modeling_auto import register_model +from ..model_factory import McaGPTModel +from ..model_utils import ModuleUtilsMixin +from .config_qwen3_vl import Qwen3VLConfig +from .rope_utils import Qwen3VLMultimodalRotaryEmbedding, get_rope_index +from .transformer_block import Qwen3VLTransformerBlock + + +class Qwen3VLGPTModel(McaGPTModel): + def __init__( + self, + config: Qwen3VLConfig, + rotary_percent: float = 1.0, + seq_len_interpolation_factor: Optional[float] = None, + **kwargs, + ) -> None: + super().__init__( + config, + rotary_percent=rotary_percent, + seq_len_interpolation_factor=seq_len_interpolation_factor, + **kwargs, + ) + + # rebuild rope + self.rotary_pos_emb = Qwen3VLMultimodalRotaryEmbedding( + kv_channels=self.config.kv_channels, + rotary_percent=rotary_percent, + rotary_interleaved=self.config.rotary_interleaved, + seq_len_interpolation_factor=seq_len_interpolation_factor, + rotary_base=self.rotary_base, + ) + self.mrope_section = self.config.mrope_section + assert self.mrope_section is not None, ( + "mrope require mrope_section setting, but we got None from TransformerConfig" + ) + + # rebuild the transformer block + self.decoder = Qwen3VLTransformerBlock( + config=self.config, + spec=self.transformer_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + vp_stage=self.vp_stage, + ) + + def forward( + self, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_context: BaseInferenceContext = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + runtime_gather_output: Optional[bool] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, + # args for deepstack + visual_pos_masks: Optional[torch.Tensor] = None, + 
deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + ) -> Tensor: + """Forward function of the GPT Model This function passes the input tensors + through the embedding layer, and then the decoeder and finally into the post + processing layer (optional). + + It either returns the Loss values if labels are given or the final hidden units + + Args: + runtime_gather_output (bool): Gather output at runtime. Default None means + `parallel_output` arg in the constructor will be used. + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + + ( + decoder_input, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + ) = self._preprocess( + input_ids=input_ids, + position_ids=position_ids, + decoder_input=decoder_input, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + ) + + # Run decoder. + hidden_states = self.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + **(extra_block_kwargs or {}), + ) + + return self._postprocess( + hidden_states=hidden_states, + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + mtp_in_postprocess=self.mtp_process, + loss_mask=loss_mask, + decoder_input=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, + ) + + +@register_model("qwen3_vl") +class Qwen3VLModel(Qwen3VLGPTModel, ModuleUtilsMixin): + config_class = Qwen3VLConfig + + def __init__(self, config: "Qwen3VLConfig", **kwargs): + from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLVisionConfig + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel + + super().__init__(config, **kwargs) + + if mpu.get_pipeline_model_parallel_rank() == 0 and self.vp_stage == 0: + assert self.decoder.num_layers_per_pipeline_rank >= len( + config.vision_config.get("deepstack_visual_indexes", [8, 16, 24]) + ), "Current pp and vp not support deepstack" + + if self.pre_process: + self.vision_model = Qwen3VLVisionModel._from_config( + Qwen3VLVisionConfig(**config.vision_config), + attn_implementation="sdpa", + torch_dtype=self.config.params_dtype, + ).to(torch.cuda.current_device()) + for param in self.vision_model.parameters(): + setattr(param, "sequence_parallel", config.sequence_parallel) + + def _handle_missing_visual(self, inputs_embeds: "torch.FloatTensor"): + mock_pixel_values = torch.zeros( + 4, self.config.pixel_values_dim, device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + mock_grid_thw = torch.LongTensor([[1, 2, 2]]).to(inputs_embeds.device) + image_embeds, deepstack_image_embeds = self.vision_model(mock_pixel_values, grid_thw=mock_grid_thw) + inputs_embeds = inputs_embeds + image_embeds.mean() * 0 + return ( + inputs_embeds, + torch.zeros((inputs_embeds.size(1), inputs_embeds.size(0)), device=inputs_embeds.device, dtype=torch.bool), + deepstack_image_embeds, + ) + + def construct_inputs_embeds( 
+ self, + input_ids: "torch.LongTensor", + inputs_embeds: "torch.FloatTensor", + pixel_values: "torch.Tensor", + grid_thw: "torch.LongTensor", + input_ranges: List[List[int]], + media_token_id: int, + ): + """ + inputs_embeds: [s, b, h] or [s/tp, b, h] when sequence parallel + ranges: sequence range + """ + image_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32 + ) + flatten_grid_thw = torch.repeat_interleave(grid_thw, grid_thw[:, 0], dim=0) + flatten_grid_thw[:, 0] = 1 + image_embeds_seqlens = image_seqlens // (self.config.merge_size**2) + assert image_seqlens[-1] == pixel_values.shape[0], ( + f"pixel_values.shape[0] {pixel_values.shape[0]} != image_seqlens[-1] {image_seqlens[-1]}" + ) + assert sum([r[1] - r[0] for r in input_ranges]) == inputs_embeds.shape[0], ( + f"sum of input_ranges {input_ranges} not match inputs_embeds.shape {inputs_embeds.shape}" + ) + image_mask = input_ids == media_token_id + + valid_image_embeds_nums = [] # indicate the ranges of needed image embeds + required_pixel_values, required_grid_thws = [], [] # image features input to vision tower + added_image_indexes = [] + for i in range(image_mask.shape[0]): + for inputs_start, inputs_end in input_ranges: + valid_image_embeds_start = image_mask[:i].sum().item() + valid_image_embeds_start += image_mask[i, :inputs_start].sum().item() + embeds_num = image_mask[i, inputs_start:inputs_end].sum().item() + valid_image_embeds_end = valid_image_embeds_start + embeds_num + used_embeds_seqlen_start = 0 # embeds seqlens used in this range + new_embeds_seqlen_start = ( + 0 # embeds seqlens new added in this range, new_embeds_seqlen_start >= used_embeds_seqlen_start + ) + embeds_seqlen_end = image_embeds_seqlens[-1] + added_seqlen_before_used = 0 + for image_index, image_embeds_seqlen in enumerate(image_embeds_seqlens): + if valid_image_embeds_start < image_embeds_seqlen: + if image_index not in added_image_indexes: + required_grid_thws.append(flatten_grid_thw[image_index]) + added_image_indexes.append(image_index) + else: + new_embeds_seqlen_start = image_embeds_seqlen + else: + used_embeds_seqlen_start = image_embeds_seqlen + new_embeds_seqlen_start = image_embeds_seqlen + if image_index in added_image_indexes: + before_seqlen = 0 if image_index == 0 else image_embeds_seqlens[image_index - 1].item() + added_seqlen_before_used += image_embeds_seqlen - before_seqlen + if valid_image_embeds_end <= image_embeds_seqlen: + embeds_seqlen_end = image_embeds_seqlen + break + + if new_embeds_seqlen_start < embeds_seqlen_end: + required_pixel_values.append( + pixel_values[ + new_embeds_seqlen_start * (self.config.merge_size**2) : embeds_seqlen_end + * (self.config.merge_size**2) + ] + ) + embeds_needed_start = valid_image_embeds_start - used_embeds_seqlen_start + added_seqlen_before_used + embeds_needed_end = valid_image_embeds_end - used_embeds_seqlen_start + added_seqlen_before_used + if embeds_needed_start < embeds_needed_end: + valid_image_embeds_nums.append((embeds_needed_start, embeds_needed_end)) + + if len(required_pixel_values) == 0: + return self._handle_missing_visual(inputs_embeds) + + required_pixel_values = torch.cat(required_pixel_values, dim=0) + required_grid_thw = torch.stack(required_grid_thws, dim=0) + vision_model_dtype = self.vision_model.blocks[0].mlp.linear_fc1.weight.dtype + required_pixel_values = required_pixel_values.type(vision_model_dtype) + image_embeds, deepstack_image_embeds = self.vision_model(required_pixel_values, 
grid_thw=required_grid_thw) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + + image_mask = torch.cat( + [image_mask[:, inputs_start:inputs_end] for inputs_start, inputs_end in input_ranges], dim=1 + ) + needed_image_embeds_num = image_mask.sum().item() + needed_image_embeds = torch.zeros( + [needed_image_embeds_num] + list(image_embeds.shape[1:]), + dtype=inputs_embeds.dtype, + device=inputs_embeds.device, + ) + + added_num = 0 + for start, end in valid_image_embeds_nums: + embeds_num = end - start + needed_image_embeds[added_num : added_num + embeds_num] = image_embeds[start:end] + added_num += embeds_num + assert added_num == needed_image_embeds_num + + inputs_embeds = inputs_embeds.transpose(0, 1) # [s, b, h] -> [b, s, h] + image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, needed_image_embeds) + inputs_embeds = inputs_embeds.transpose(0, 1).contiguous() + + # construct deepstack embedding + image_mask = image_mask[..., 0] + visual_pos_masks = image_mask + deepstack_visual_embeds = [] + for deepstack_image_embed in deepstack_image_embeds: + needed_deepstack_image_embeds = torch.zeros( + [needed_image_embeds_num] + list(deepstack_image_embed.shape[1:]), + dtype=inputs_embeds.dtype, + device=inputs_embeds.device, + ) + added_num = 0 + for start, end in valid_image_embeds_nums: + embeds_num = end - start + needed_deepstack_image_embeds[added_num : added_num + embeds_num] = deepstack_image_embed[start:end] + added_num += embeds_num + assert added_num == needed_image_embeds_num + deepstack_visual_embeds.append(needed_deepstack_image_embeds) + + return inputs_embeds, visual_pos_masks, deepstack_visual_embeds + + def get_batch_on_this_cp_rank(self, batch, dim3_keys: List[str] = ["attention_mask"]): + # VLM need to view all input_ids and media features + loss_needed_items = { + "labels": batch.pop("labels", None), + } + loss_needed_items = super().get_batch_on_this_cp_rank(loss_needed_items, dim3_keys=dim3_keys) + batch.update(loss_needed_items) + return batch + + def get_input_ranges(self, total_seqlen): + # context parallel 的计算有问题 + slice_rank, slice_size = 0, 1 + if self.config.sequence_parallel: + slice_rank = mpu.get_tensor_model_parallel_rank() + slice_size = mpu.get_tensor_model_parallel_world_size() + + def get_sequence_range(start, end, rank, size): + return start + (end - start) * rank // size, start + (end - start) * (rank + 1) // size + + if self.config.context_parallel_size <= 1: + return [list(get_sequence_range(0, total_seqlen, slice_rank, slice_size))] + cp_rank = mpu.get_context_parallel_rank() + cp_size = mpu.get_context_parallel_world_size() + left_start = (total_seqlen // cp_size // 2) * cp_rank + left_end = (total_seqlen // cp_size // 2) * (cp_rank + 1) + right_start = total_seqlen - left_end + right_end = total_seqlen - left_start + slice_len = (left_end - left_start + right_end - right_start) // slice_size + start = left_start + slice_len * slice_rank + end = start + slice_len + if start >= left_end: + start = start - left_end + right_start + end = start + slice_len + return [[start, end]] + if end <= left_end: + return [[start, end]] + end = end - left_end + right_start + return [[start, left_end], [right_start, end]] + + def forward( + self, + input_ids: "torch.Tensor", + position_ids: Optional["torch.Tensor"] = None, + attention_mask: Optional["torch.Tensor"] = None, + decoder_input: Optional["torch.Tensor"] = None, + labels: Optional["torch.Tensor"] = None, + 
pixel_values: Optional["torch.Tensor"] = None, + pixel_values_videos: Optional["torch.Tensor"] = None, + image_grid_thw: Optional["torch.LongTensor"] = None, + video_grid_thw: Optional["torch.LongTensor"] = None, + **kwargs, + ) -> "torch.Tensor": + if position_ids is None and input_ids is not None: + position_ids, _ = get_rope_index(self.config, input_ids, image_grid_thw, video_grid_thw) + + cp_batch = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + if self.config.context_parallel_size > 1: + cp_batch = {k: v.clone() if v is not None else None for k, v in cp_batch.items()} + cp_batch = super().get_batch_on_this_cp_rank(cp_batch, dim3_keys=["attention_mask", "position_ids"]) + + if not self.pre_process or (pixel_values is None and pixel_values_videos is None) or decoder_input is not None: + return super().forward( + decoder_input=decoder_input, labels=labels, position_ids=position_ids, **cp_batch, **kwargs + ) + + inputs_ranges = self.get_input_ranges(input_ids.shape[1]) + + inputs_embeds = self.embedding(input_ids=cp_batch["input_ids"], position_ids=None) + if pixel_values is not None: + # get deepstack emb + inputs_embeds, visual_pos_masks, deepstack_visual_embeds = self.construct_inputs_embeds( + input_ids, + inputs_embeds, + pixel_values, + image_grid_thw, + inputs_ranges, + self.config.image_token_id, + ) + elif pixel_values_videos is not None: + inputs_embeds, visual_pos_masks, deepstack_visual_embeds = self.construct_inputs_embeds( + input_ids, + inputs_embeds, + pixel_values_videos, + video_grid_thw, + inputs_ranges, + self.config.video_token_id, + ) + return super().forward( + decoder_input=inputs_embeds, + labels=labels, + position_ids=position_ids, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + **cp_batch, + **kwargs, + ) diff --git a/mcore_adapter/src/mcore_adapter/models/qwen3_vl/rope_utils.py b/mcore_adapter/src/mcore_adapter/models/qwen3_vl/rope_utils.py new file mode 100644 index 000000000..61ed6e3bd --- /dev/null +++ b/mcore_adapter/src/mcore_adapter/models/qwen3_vl/rope_utils.py @@ -0,0 +1,257 @@ +from typing import Optional, Tuple + +import torch +from megatron.core import parallel_state +from megatron.core.models.common.embeddings.rope_utils import ( + _apply_rotary_pos_emb_bshd, + get_pos_emb_on_this_cp_rank, +) +from torch import nn + +from .config_qwen3_vl import Qwen3VLConfig + + +class Qwen3VLMultimodalRotaryEmbedding(nn.Module): + """Rotary Embedding for language model. + Based on https://github.com/lostkevin/Pai-Megatron-Patch/blob/ + 11b44c3cfe95defb006641b085f8657702f25f10/megatron_patch/model/qwen3_vl/rotary_pos_embedding.py + + Args: + kv_channels (int): Projection weights dimension in multi-head attention. Obtained + from transformer config + rotary_percent (float): Percent of rotary dimension to use for rotary position + embeddings. + rotary_interleaved (bool, optional): If True, interleaved rotary position embeddings. + Defaults to False. + seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE + for longer sequences. The value must be a float larger than 1.0. Defaults to None + rotary_base (int, optional): Base period for rotary position embeddings. Defaults to + 10000. 
+ """ + + def __init__( + self, + kv_channels: int, + rotary_percent: float, + rotary_interleaved: bool = False, + seq_len_interpolation_factor: float = None, + rotary_base: int = 10000, + ) -> None: + super().__init__() + + dim = kv_channels + if rotary_percent < 1.0: + dim = int(dim * rotary_percent) + self.rotary_interleaved = rotary_interleaved + assert not self.rotary_interleaved, "Qwen3VLMultimodalRotaryEmbedding does not support rotary_interleaved" + + self.seq_len_interpolation_factor = seq_len_interpolation_factor + self.inv_freq = 1.0 / ( + rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device()) / dim) + ) + + self.is_thd_format = False # if is thd format, we do not need to split the rotary_pos_emb along CP + + def apply_interleaved_mrope(self, freqs, mrope_section): + """Apply interleaved MRoPE to 3D rotary embeddings. + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THTHWHTHW...TT], preserving frequency continuity. + args: + x: (3, bs, seq_len, head_dim // 2) + mrope_section: (3,) + returns: + x_t: (bs, seq_len, head_dim // 2) + """ + freqs_t = freqs[0] # just overwrite the first dimension T + for dim, offset in enumerate((1, 2), start=1): # H, W + length = mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t + + def forward(self, position_ids: torch.Tensor, mrope_section: list[int]) -> torch.Tensor: + """Forward pass of multimodal RoPE embedding. + + Args: + position_ids (torch.Tensor): A position_id tensor with shape [3, batchsize, seqlens] + mrope_section (list[int]): Multimodal rope section is for channel dimension of temporal, + height and width in rope calculation. + + Returns: + Tensor: Embeddings after applying RoPE. + """ + seq = position_ids.to(device=self.inv_freq.device, dtype=self.inv_freq.dtype) + + if self.seq_len_interpolation_factor is not None: + seq *= 1 / self.seq_len_interpolation_factor + + inv_freq_expanded = self.inv_freq[None, None, :, None].expand(3, seq.shape[1], -1, 1) # shape (3, bs, dim, 1) + seq_expanded = seq[:, :, None, :].float() # shape (3, bs, 1, seq_length) + freqs = (inv_freq_expanded @ seq_expanded).transpose(2, 3) # shape (3, bs, seq_length, dim) + freqs = self.apply_interleaved_mrope(freqs, mrope_section) # shape (bs, seq_length, dim) + + # first part even vector components, second part odd vector components, + # 2 * dim in dimension size + # sin, sin, ..., cos, cos, ... + emb = torch.cat((freqs, freqs), dim=-1) # shape (bs, seq_length, 2 * dim) + + # shape (seq_length, bs, 1, 2 * dim) + emb = emb[..., None, :].transpose(0, 1).contiguous() + if parallel_state.get_context_parallel_world_size() > 1: + # slice rotary_pos_emb along sequence dimension and select the parition of the current + # CP rank + emb = get_pos_emb_on_this_cp_rank(emb, 0, parallel_state.get_context_parallel_group()) + return emb + + +# def apply_rotary_pos_emb_thd_absolute( +# t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor, rotary_interleaved: bool = False +# ) -> torch.Tensor: +# """A baseline implementation of applying RoPE for `thd` format. + +# Args: +# t (Tensor): Input tensor T is of shape [t, h, d] +# cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, +# with shape [b + 1] and dtype torch.int32. +# freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] + +# Returns: +# Tensor: Shape [t, h, d]. The input tensor after applying RoPE. 
+# """ +# return _apply_rotary_pos_emb_bshd(t[:, None], freqs, rotary_interleaved=rotary_interleaved).squeeze(1) + + +# def apply_rotary_pos_emb_absolute( +# t: torch.Tensor, +# freqs: torch.Tensor, +# config: Qwen3VLConfig, +# cu_seqlens: Optional[torch.Tensor] = None, +# ): +# """ +# Reroute to the appropriate apply_rotary_pos_emb function depending on +# bshd (conventional) / thd (packed seq) format + +# In Qwen3-VL, the shape of freqs is (seq_length, bs, 1, 2 * dim) instead of [max_seqlen, 1, 1, 2 * dim] +# """ +# assert not config.apply_rope_fusion + +# if cu_seqlens is None: +# result = _apply_rotary_pos_emb_bshd(t, freqs, rotary_interleaved=config.rotary_interleaved) +# else: +# result = apply_rotary_pos_emb_thd_absolute(t, cu_seqlens, freqs, rotary_interleaved=config.rotary_interleaved) + +# return result + + +# copy from transformers==4.57.0 +def get_rope_index( + config: Qwen3VLConfig, + input_ids: torch.LongTensor, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids.""" + + # Since we use timestamps to separate videos, like , the video_grid_thw should also be split + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 + + spatial_merge_size = config.merge_size + image_token_id = config.image_token_id + video_token_id = config.video_token_id + vision_start_token_id = config.vision_start_token_id + + mrope_position_deltas = [] + attention_mask = torch.ones(input_ids.shape, dtype=input_ids.dtype, device=input_ids.device) + if image_grid_thw is not None or video_grid_thw is not None: + total_input_ids = input_ids + position_ids = torch.ones( + 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device + ) + image_index, video_index = 0, 0 + for i, input_ids in enumerate(total_input_ids): + if attention_mask is not None: + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + 
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas diff --git a/mcore_adapter/src/mcore_adapter/models/qwen3_vl/transformer_block.py b/mcore_adapter/src/mcore_adapter/models/qwen3_vl/transformer_block.py new file mode 100644 index 000000000..02775d448 --- /dev/null +++ b/mcore_adapter/src/mcore_adapter/models/qwen3_vl/transformer_block.py @@ -0,0 +1,360 @@ +from contextlib import nullcontext +from typing import Optional, Union + +import torch +from megatron.core import parallel_state, tensor_parallel +from megatron.core.enums import Fp8Recipe +from megatron.core.fp8_utils import get_fp8_context +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.transformer_block import ( + TransformerBlock, +) + +# from megatron.core.process_groups_config import ModelCommProcessGroups +from megatron.core.utils import ( + WrappedTensor, + deprecate_inference_params, + make_viewless_tensor, +) +from torch import Tensor + + +try: + import transformer_engine.pytorch as te # pylint: disable=unused-import + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +te_checkpoint = None +if HAVE_TE: + from megatron.core.extensions.transformer_engine import te_checkpoint + + +class Qwen3VLTransformerBlock(TransformerBlock): + """Transformer class.""" + + def _checkpointed_forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + context: Tensor, + context_mask: Tensor, + rotary_pos_emb: Tensor, + attention_bias: Tensor, + packed_seq_params: PackedSeqParams, + use_inner_fp8_context: bool, + # args for deepstack + visual_pos_masks: Optional[torch.Tensor] = None, + deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + ): + """Forward method with 
activation checkpointing.""" + + def custom(start: int, end: int): + def custom_forward( + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + ): + for index in range(start, end): + layer = self._get_layer(index) + inner_fp8_context = ( + get_fp8_context(self.config, layer.layer_number - 1) + if use_inner_fp8_context + else nullcontext() + ) + with inner_fp8_context: + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + inference_context=None, + packed_seq_params=packed_seq_params, + ) + + return hidden_states, context + + return custom_forward + + def checkpoint_handler(forward_func): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + if self.config.fp8: + return te_checkpoint( + forward_func, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + parallel_state.get_tensor_model_parallel_group(), + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + ) + else: + return tensor_parallel.checkpoint( + forward_func, + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + ) + + if self.config.recompute_method == "uniform": + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + layer_idx = 0 + assert self.config.recompute_num_layers == 1 + while layer_idx < self.num_layers_per_pipeline_rank: + hidden_states, context = checkpoint_handler( + custom(layer_idx, layer_idx + self.config.recompute_num_layers) + ) + + if self.pre_process and deepstack_visual_embeds is not None: + layer = self._get_layer(layer_idx) + assert layer_idx == layer.layer_number - 1 + if layer_idx in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + + layer_idx += self.config.recompute_num_layers + + elif self.config.recompute_method == "block": + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + recompute_skip_num_layers = 0 + for layer_idx in range(self.num_layers_per_pipeline_rank): + # Skip recomputation when input grad computation is not needed. + # Need to have at least one input tensor with gradient computation + # for re-enterant autograd engine. 
+ if self.config.fp8 and not hidden_states.requires_grad: + recompute_skip_num_layers += 1 + if ( + layer_idx >= recompute_skip_num_layers + and layer_idx + < self.config.recompute_num_layers + recompute_skip_num_layers + ): + hidden_states, context = checkpoint_handler( + custom(layer_idx, layer_idx + 1) + ) + else: + hidden_states, context = custom(layer_idx, layer_idx + 1)( + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + ) + + if self.pre_process and deepstack_visual_embeds is not None: + layer = self._get_layer(layer_idx) + assert layer_idx == layer.layer_number - 1 + if layer_idx in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_idx], + ) + else: + raise ValueError("Invalid activation recompute method.") + + return hidden_states + + def forward( + self, + hidden_states: Union[Tensor, WrappedTensor], + attention_mask: Optional[Tensor], + context: Optional[Tensor] = None, + context_mask: Optional[Tensor] = None, + rotary_pos_emb: Optional[Tensor] = None, + rotary_pos_cos: Optional[Tensor] = None, + rotary_pos_sin: Optional[Tensor] = None, + attention_bias: Optional[Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[Tensor] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + # args for deepstack + visual_pos_masks: Optional[torch.Tensor] = None, + deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + ): + """ + Perform the forward pass through the transformer block. + + This method handles the core computation of the transformer, including + self-attention, optional cross-attention, and feed-forward operations. + + Args: + hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h] + where s is the sequence length, b is the batch size, and h is the hidden size. + Can be passed as a WrappedTensor during inference to avoid an obsolete + reference in the calling function. + attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking + self-attention. + context (Tensor, optional): Context tensor for cross-attention. + context_mask (Tensor, optional): Mask for cross-attention context + rotary_pos_emb (Tensor, optional): Rotary positional embeddings. + attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable + to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. + Used as an alternative to apply attention mask for TE cuDNN attention. + inference_context (BaseInferenceContext, optional): Parameters for inference-time + optimizations. + packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence + processing. + + Returns: + Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape + [s, b, h], and optionally the updated context tensor if cross-attention is used. + """ + if self.pre_process and deepstack_visual_embeds is not None: + assert len(deepstack_visual_embeds) < len( + self.layers + ), "the deepstack_visual_embeds should be on the first pp-stage" + + inference_context = deprecate_inference_params( + inference_context, inference_params + ) + + # Delete the obsolete reference to the initial input tensor if necessary + if isinstance(hidden_states, WrappedTensor): + hidden_states = hidden_states.unwrap() + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. 
+ # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True + ) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), + # otherwise do nothing extra at the outer level + # if we are using other fp8 recipes, then the context manager enter&exit are free + # we can wrap fp8_context within the for loop over layers, so that we can fine-grained + # control which layer will be fp8 or bf16 + use_outer_fp8_context = ( + self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed + ) + use_inner_fp8_context = ( + self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed + ) + outer_fp8_context = ( + get_fp8_context(self.config) if use_outer_fp8_context else nullcontext() + ) + + with rng_context, outer_fp8_context: + # Forward pass. + if self.config.recompute_granularity == "full" and self.training: + hidden_states = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + use_inner_fp8_context=use_inner_fp8_context, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + ) + else: + for l_no, layer in enumerate(self.layers): + inner_fp8_context = ( + get_fp8_context(self.config, layer.layer_number - 1) + if use_inner_fp8_context + else nullcontext() + ) + with self.offload_context, inner_fp8_context: + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + + if self.pre_process and deepstack_visual_embeds is not None: + assert l_no == layer.layer_number - 1 + if l_no in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[l_no], + ) + + if ( + torch.is_grad_enabled() + and self.config.cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async( + hidden_states + ) + + # Final layer norm. + if self.final_layernorm is not None: + hidden_states = self.final_layernorm(hidden_states) + # TENorm produces a "viewed" tensor. 
This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True + ) + + # If this TransformerBlock is empty, input and output hidden states will be the same node + # on the computational graph and will lead to unexpected errors in pipeline schedules. + if not self.pre_process and len(self.layers) == 0 and not self.final_layernorm: + hidden_states = hidden_states.clone() + + return hidden_states + + def _deepstack_process( + self, + hidden_states: torch.Tensor, + visual_pos_masks: torch.Tensor, + visual_embeds: torch.Tensor, + ): + hidden_states = hidden_states.transpose(0, 1).contiguous() + local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds + hidden_states[visual_pos_masks, :] = local_this + hidden_states = hidden_states.transpose(0, 1).contiguous() + return hidden_states From b5cd1eac636b587e0a788b3694aa8dadb65f812f Mon Sep 17 00:00:00 2001 From: "hongzhen.yj" Date: Tue, 18 Nov 2025 19:21:01 +0800 Subject: [PATCH 44/58] (fix): add force_vit flags for image and video processing in Qwen3 VL model. --- .../mcore_adapter/models/qwen3_vl/modeling_qwen3_vl.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mcore_adapter/src/mcore_adapter/models/qwen3_vl/modeling_qwen3_vl.py b/mcore_adapter/src/mcore_adapter/models/qwen3_vl/modeling_qwen3_vl.py index 600fe1ab7..3b6ec224a 100644 --- a/mcore_adapter/src/mcore_adapter/models/qwen3_vl/modeling_qwen3_vl.py +++ b/mcore_adapter/src/mcore_adapter/models/qwen3_vl/modeling_qwen3_vl.py @@ -345,6 +345,8 @@ def forward( video_grid_thw: Optional["torch.LongTensor"] = None, **kwargs, ) -> "torch.Tensor": + force_vit_image = kwargs.pop("force_vit_image", False) + force_vit_video = kwargs.pop("force_vit_video", False) if position_ids is None and input_ids is not None: position_ids, _ = get_rope_index(self.config, input_ids, image_grid_thw, video_grid_thw) @@ -374,7 +376,9 @@ def forward( inputs_ranges, self.config.image_token_id, ) - elif pixel_values_videos is not None: + elif force_vit_image: + inputs_embeds, visual_pos_masks, deepstack_visual_embeds = self._handle_missing_visual(inputs_embeds) + if pixel_values_videos is not None: inputs_embeds, visual_pos_masks, deepstack_visual_embeds = self.construct_inputs_embeds( input_ids, inputs_embeds, @@ -383,6 +387,8 @@ def forward( inputs_ranges, self.config.video_token_id, ) + elif force_vit_video: + inputs_embeds, visual_pos_masks, deepstack_visual_embeds = self._handle_missing_visual(inputs_embeds) return super().forward( decoder_input=inputs_embeds, labels=labels, From e30bb72aa7af67e47342a41cae5ef83ee94ecac9 Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Fri, 5 Dec 2025 17:19:36 +0800 Subject: [PATCH 45/58] (feat): add qwen3-vl example. 
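This commit only adds an example configuration and launch script for Qwen3-VL-4B RLVR training on the Megatron backend; no library code is changed. A minimal usage sketch, assuming the repository root as the working directory and that the site-specific dataset and checkpoint paths in rlvr_megatron.yaml have been adjusted:

    bash examples/qwen3-vl-4B-rlvr_megatron/run_rlvr_pipeline.sh

The script derives the config directory from its own location and launches examples/start_rlvr_vl_pipeline.py with --config_name rlvr_megatron.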
--- .../rlvr_megatron.yaml | 161 ++++++++++++++++++ .../run_rlvr_pipeline.sh | 5 + 2 files changed, 166 insertions(+) create mode 100644 examples/qwen3-vl-4B-rlvr_megatron/rlvr_megatron.yaml create mode 100644 examples/qwen3-vl-4B-rlvr_megatron/run_rlvr_pipeline.sh diff --git a/examples/qwen3-vl-4B-rlvr_megatron/rlvr_megatron.yaml b/examples/qwen3-vl-4B-rlvr_megatron/rlvr_megatron.yaml new file mode 100644 index 000000000..4b6c12123 --- /dev/null +++ b/examples/qwen3-vl-4B-rlvr_megatron/rlvr_megatron.yaml @@ -0,0 +1,161 @@ +defaults: + - ../config/deepspeed_zero@_here_ + - ../config/deepspeed_zero2@_here_ + - ../config/deepspeed_zero3@_here_ + - ../config/deepspeed_zero3_cpuoffload@_here_ + +hydra: + run: + dir: . + output_subdir: null + +exp_name: "qwen3_vl_4B_rlvr" +seed: 42 +logging_dir: ./output/logs +output_dir: ./output + +checkpoint_config: + type: file_system + output_dir: /data/cpfs_0/yuzhao/models + +track_with: tensorboard +tracker_kwargs: + log_dir: /data/oss_bucket_0/yuzhao/llm/tensorboard + +save_steps: 20 +logging_steps: 1 +eval_steps: 10 +resume_from_checkpoint: false + +rollout_batch_size: 256 +num_return_sequences_in_group: 8 +is_num_return_sequences_expand: true +prompt_length: 2048 +response_length: 4096 + +ppo_epochs: 1 +value_clip: 0.5 +reward_clip: 10 +advantage_clip: 10.0 +whiten_advantages: false +init_kl_coef: 0.0 +adv_estimator: "grpo" +use_kl_loss: true +kl_loss_coef: 1.0e-2 + +pretrain: Qwen/Qwen3-VL-4B-Thinking + +validation: + data_args: + file_name: + - /data/oss_bucket_0/yuzhao/data/One-RL-to-See-Them-All/Orsta-Data-47k/test/test_math_megabench_237.parquet + - /data/oss_bucket_0/yuzhao/data/One-RL-to-See-Them-All/Orsta-Data-47k/test/test_detection_coco_test_multi_2000.parquet + dataset_dir: ./ + generating_args: + max_new_tokens: ${response_length} + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: 1 + eval_steps: ${eval_steps} + +actor_train: + model_args: + flash_attn: fa2 + attn_implementation: fa2 + disable_gradient_checkpointing: false + dtype: bf16 + model_type: ~ + freeze_module_prefix: vision_model + training_args: + learning_rate: 1.0e-6 + weight_decay: 1.0e-2 + per_device_train_batch_size: 2 + gradient_accumulation_steps: 64 + warmup_steps: 0 + num_train_epochs: 50 + data_args: + # use One-RL-to-See-Them-All/Orsta-Data-47k as train dataset + # download from https://huggingface.co/datasets/One-RL-to-See-Them-All/Orsta-Data-47k + file_name: + - /data/oss_bucket_0/yuzhao/data/One-RL-to-See-Them-All/Orsta-Data-47k/train/train_detection_v3det_4000.parquet + - /data/oss_bucket_0/yuzhao/data/One-RL-to-See-Them-All/Orsta-Data-47k/train/train_math_mmmath_3539.parquet + domain_interleave_probs: + math: 0.5 + cv_detection: 0.5 + dataset_dir: ./ + messages: prompt + preprocessing_num_workers: 32 + strategy_args: + strategy_name: megatron_train + strategy_config: + sequence_parallel: true + tensor_model_parallel_size: 1 + context_parallel_size: 1 + expert_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + overlap_grad_reduce: true + use_distributed_optimizer: true + bf16: true + device_mapping: list(range(0,8)) + infer_batch_size: 2 + +actor_infer: + model_args: + flash_attn: fa2 + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + generating_args: + max_new_tokens: ${response_length} + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: ${num_return_sequences_in_group} + strategy_args: + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.8 + 
block_size: 16 + num_gpus_per_worker: 1 + device_mapping: list(range(0,8)) + infer_batch_size: 32 + +reference: + model_args: + flash_attn: fa2 + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: ~ + strategy_args: + strategy_name: megatron_infer + strategy_config: + sequence_parallel: true + tensor_model_parallel_size: 1 + context_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + bf16: true + device_mapping: list(range(0,8)) + infer_batch_size: 8 + +rewards: + math: + worker_cls: roll.pipeline.rlvr.rewards.math_rule_reward_worker.MathRuleRewardWorker + model_args: + model_name_or_path: ${pretrain} + # data source whose ability is math in One-RL-to-See-Them-All/Orsta-Data-47k + tag_included: [mm_math, megabench_math] + world_size: 8 + infer_batch_size: 1 + cv_detection: + worker_cls: roll.pipeline.rlvr.rewards.detection_reward_worker.DetectionRewardWorker + model_args: + model_name_or_path: ${pretrain} + # data source whose ability is cv_detection in One-RL-to-See-Them-All/Orsta-Data-47k + tag_included: [v3det_train, object365_train, coco_val_multi_test] + world_size: 8 + infer_batch_size: 1 diff --git a/examples/qwen3-vl-4B-rlvr_megatron/run_rlvr_pipeline.sh b/examples/qwen3-vl-4B-rlvr_megatron/run_rlvr_pipeline.sh new file mode 100644 index 000000000..9a8f0ef42 --- /dev/null +++ b/examples/qwen3-vl-4B-rlvr_megatron/run_rlvr_pipeline.sh @@ -0,0 +1,5 @@ +#!/bin/bash +set +x + +CONFIG_PATH=$(basename $(dirname $0)) +python examples/start_rlvr_vl_pipeline.py --config_path $CONFIG_PATH --config_name rlvr_megatron From 2f9f2dfae7bc912164519b7d8d5b179e18fc23cd Mon Sep 17 00:00:00 2001 From: "weixun.wwx" Date: Fri, 21 Nov 2025 14:09:42 +0800 Subject: [PATCH 46/58] (feat): mock infer. --- roll/configs/worker_config.py | 4 +- roll/distributed/strategy/factory.py | 2 + roll/distributed/strategy/mock_strategy.py | 95 ++++++++++++++++++++++ 3 files changed, 99 insertions(+), 2 deletions(-) create mode 100644 roll/distributed/strategy/mock_strategy.py diff --git a/roll/configs/worker_config.py b/roll/configs/worker_config.py index eae29bad0..6859ed062 100644 --- a/roll/configs/worker_config.py +++ b/roll/configs/worker_config.py @@ -12,11 +12,11 @@ @dataclass class StrategyArguments: strategy_name: Literal[ - "deepspeed_train", "hf_infer", "deepspeed_infer", "vllm", "sglang", "megatron_infer", "megatron_train", "diffusion_deepspeed_train" + "deepspeed_train", "hf_infer", "deepspeed_infer", "vllm", "sglang", "megatron_infer", "megatron_train", "mock_infer", "diffusion_deepspeed_train" ] = field( default="deepspeed_train", metadata={ - "help": "The name of the strategy. Options: 'deepspeed_train', 'diffusion_deepspeed_train', 'hf_infer', 'deepspeed_infer', 'vllm', 'sglang', " + "help": "The name of the strategy. Options: 'deepspeed_train', 'diffusion_deepspeed_train', 'hf_infer', 'deepspeed_infer', 'mock_infer', 'vllm', 'sglang', " "'megatron_infer', 'megatron_train'." 
}, ) diff --git a/roll/distributed/strategy/factory.py b/roll/distributed/strategy/factory.py index ba35598d8..e408fd929 100644 --- a/roll/distributed/strategy/factory.py +++ b/roll/distributed/strategy/factory.py @@ -24,6 +24,8 @@ def create_strategy(worker: Worker) -> Union[InferenceStrategy, TrainStrategy]: from roll.distributed.strategy.megatron_strategy import MegatronInferStrategy as strategy_cls elif strategy_name == "megatron_train": from roll.distributed.strategy.megatron_strategy import MegatronTrainStrategy as strategy_cls + elif strategy_name == "mock_infer": + from roll.distributed.strategy.mock_strategy import MockInferStrategy as strategy_cls else: raise ValueError(f"Unknown strategy name: {strategy_name}") diff --git a/roll/distributed/strategy/mock_strategy.py b/roll/distributed/strategy/mock_strategy.py new file mode 100644 index 000000000..fcf626732 --- /dev/null +++ b/roll/distributed/strategy/mock_strategy.py @@ -0,0 +1,95 @@ +from concurrent import futures +from collections import defaultdict +from datetime import timedelta +from typing import List, Optional, Callable, Dict, Tuple + +import deepspeed +import torch +import torch.distributed as dist +from accelerate import cpu_offload_with_hook +from accelerate.hooks import UserCpuOffloadHook +from roll.utils.collective import collective +from torch.nn.utils.rnn import pad_sequence +from transformers import set_seed + +from roll.datasets.collator import collate_fn_to_dict_list +from roll.distributed.scheduler.protocol import DataProto +from roll.distributed.strategy.strategy import InferenceStrategy +from roll.models.func_providers import log_probs_forward_step_func +from roll.models.model_providers import default_tokenizer_provider +from roll.utils.logging import get_logger +from roll.utils.offload_states import OffloadStateType, offload_hf_model, load_hf_model +from roll.platforms import current_platform + +logger = get_logger() + + +class MockInferStrategy(InferenceStrategy): + strategy_name = "mock_infer" + + def __init__(self, worker: "Worker"): + super().__init__(worker) + self.executor: futures.ThreadPoolExecutor = futures.ThreadPoolExecutor(max_workers=1) + self.generate_config = None + + def initialize(self, model_provider): + set_seed(seed=self.worker.pipeline_config.seed) + dist.init_process_group(backend=current_platform.communication_backend, timeout=timedelta(minutes=self.worker_config.backend_timeout)) + dist.all_reduce(torch.zeros(1).to(current_platform.device_type)) + + self.worker.rank_info.dp_rank = dist.get_rank() + self.worker.rank_info.dp_size = dist.get_world_size() + + # 是否最少存个tokenizer + self.tokenizer = default_tokenizer_provider(model_args=self.worker_config.model_args) + # TODO:是否需要model? 
+ # logger.info(f"{self.model}") + + def forward_step( + self, + batch: DataProto, + forward_func: Callable[[DataProto, torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]], + ) -> Dict[str, torch.Tensor]: + # TODO 补充一下results的格式 + input_ids = batch.batch["input_ids"] + # 创建 placeholder log_probs,形状与 input_ids 相同 + seq_len = input_ids.size(1) + target_len = max(seq_len - 1, 0) + log_probs = torch.zeros( + input_ids.size(0), target_len, dtype=torch.float32, device=input_ids.device + ) + entropy = torch.zeros( + input_ids.size(0), target_len, dtype=torch.float32, device=input_ids.device + ) + results = {"log_probs": log_probs, "entropy": entropy} + return results + + def generate(self, batch: DataProto, generation_config): + # TODO 补充一下output的格式 + input_ids = batch.batch["input_ids"] + batch_size = input_ids.shape[0] + input_length = input_ids.shape[1] + # 获取生成的最大新token数,如果没有则使用默认值 + max_new_tokens = generation_config.get("max_new_tokens", generation_config.get("max_length", 50)) + # 生成的序列长度 = 输入长度 + 新生成的token数 + output_length = input_length + max_new_tokens + # 创建 placeholder output,形状为 (batch_size, output_length) + output = torch.zeros(batch_size, output_length, dtype=input_ids.dtype, device=input_ids.device) + return output + + def unwrap_model(self): + # return self.model + raise NotImplementedError + + def update_parameter(self, model_update_name, parameter_name, weight, ranks_in_worker): + logger.warning(f"update_parameter method is not implemented in {self.strategy_name} strategy") + + def update_parameter_in_bucket(self, model_update_name, meta_infos, buffer, ranks_in_worker): + logger.warning(f"update_parameter_in_bucket method is not implemented in {self.strategy_name} strategy") + + # offload/load 相关接口 + def load_states(self, *args, **kwargs): + logger.warning(f"load_states method is not implemented in {self.strategy_name} strategy") + + def offload_states(self, include=None, non_blocking=False): + logger.warning(f"offload_states method is not implemented in {self.strategy_name} strategy") From 3e4633ea314db4969ae92adc4903a57e7b51c435 Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Fri, 5 Dec 2025 17:20:59 +0800 Subject: [PATCH 47/58] (feat): add qwen3-vl 32B example. --- .../rlvr_megatron.yaml | 162 ++++++++++++++++++ .../run_rlvr_pipeline.sh | 5 + .../models/qwen3_vl/modeling_qwen3_vl.py | 2 +- 3 files changed, 168 insertions(+), 1 deletion(-) create mode 100644 examples/qwen3-vl-32B-rlvr_megatron/rlvr_megatron.yaml create mode 100644 examples/qwen3-vl-32B-rlvr_megatron/run_rlvr_pipeline.sh diff --git a/examples/qwen3-vl-32B-rlvr_megatron/rlvr_megatron.yaml b/examples/qwen3-vl-32B-rlvr_megatron/rlvr_megatron.yaml new file mode 100644 index 000000000..f578ae952 --- /dev/null +++ b/examples/qwen3-vl-32B-rlvr_megatron/rlvr_megatron.yaml @@ -0,0 +1,162 @@ +defaults: + - ../config/deepspeed_zero@_here_ + - ../config/deepspeed_zero2@_here_ + - ../config/deepspeed_zero3@_here_ + - ../config/deepspeed_zero3_cpuoffload@_here_ + +hydra: + run: + dir: . 
+ output_subdir: null + +exp_name: "qwen3_vl_32B_rlvr" +seed: 42 +logging_dir: ./output/logs +output_dir: ./output + +checkpoint_config: + type: file_system + output_dir: /data/cpfs_0/yuzhao/models + +track_with: tensorboard +tracker_kwargs: + log_dir: /data/oss_bucket_0/yuzhao/llm/tensorboard + +save_steps: 20 +logging_steps: 1 +eval_steps: 10 +resume_from_checkpoint: false + +rollout_batch_size: 256 +num_return_sequences_in_group: 8 +is_num_return_sequences_expand: true +prompt_length: 2048 +response_length: 4096 + +ppo_epochs: 1 +value_clip: 0.5 +reward_clip: 10 +advantage_clip: 10.0 +whiten_advantages: false +init_kl_coef: 0.0 +adv_estimator: "grpo" +use_kl_loss: true +kl_loss_coef: 1.0e-2 + +pretrain: Qwen/Qwen3-VL-32B-Thinking + +validation: + data_args: + file_name: + - /data/oss_bucket_0/yuzhao/data/One-RL-to-See-Them-All/Orsta-Data-47k/test/test_math_megabench_237.parquet + - /data/oss_bucket_0/yuzhao/data/One-RL-to-See-Them-All/Orsta-Data-47k/test/test_detection_coco_test_multi_2000.parquet + dataset_dir: ./ + generating_args: + max_new_tokens: ${response_length} + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: 1 + eval_steps: ${eval_steps} + +actor_train: + model_args: + flash_attn: fa2 + attn_implementation: fa2 + disable_gradient_checkpointing: false + dtype: bf16 + model_type: ~ + freeze_module_prefix: vision_model + training_args: + learning_rate: 1.0e-6 + weight_decay: 1.0e-2 + per_device_train_batch_size: 2 + gradient_accumulation_steps: 64 + warmup_steps: 0 + num_train_epochs: 50 + data_args: + # use One-RL-to-See-Them-All/Orsta-Data-47k as train dataset + # download from https://huggingface.co/datasets/One-RL-to-See-Them-All/Orsta-Data-47k + file_name: + - /data/oss_bucket_0/yuzhao/data/One-RL-to-See-Them-All/Orsta-Data-47k/train/train_detection_v3det_4000.parquet + - /data/oss_bucket_0/yuzhao/data/One-RL-to-See-Them-All/Orsta-Data-47k/train/train_math_mmmath_3539.parquet + domain_interleave_probs: + math: 0.5 + cv_detection: 0.5 + dataset_dir: ./ + messages: prompt + preprocessing_num_workers: 32 + strategy_args: + strategy_name: megatron_train + strategy_config: + sequence_parallel: true + tensor_model_parallel_size: 4 + context_parallel_size: 2 + expert_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + overlap_grad_reduce: true + use_distributed_optimizer: true + bf16: true + device_mapping: list(range(0,32)) + infer_batch_size: 8 + +actor_infer: + model_args: + flash_attn: fa2 + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + generating_args: + max_new_tokens: ${response_length} + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: ${num_return_sequences_in_group} + strategy_args: + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.7 + block_size: 16 + max_model_len: 8192 + num_gpus_per_worker: 4 + device_mapping: list(range(0,32)) + infer_batch_size: 32 + +reference: + model_args: + flash_attn: fa2 + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: ~ + strategy_args: + strategy_name: megatron_infer + strategy_config: + sequence_parallel: true + tensor_model_parallel_size: 2 + context_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + bf16: true + device_mapping: list(range(0,32)) + infer_batch_size: 8 + +rewards: + math: + worker_cls: roll.pipeline.rlvr.rewards.math_rule_reward_worker.MathRuleRewardWorker + model_args: + model_name_or_path: ${pretrain} + # data 
source whose ability is math in One-RL-to-See-Them-All/Orsta-Data-47k + tag_included: [mm_math, megabench_math] + world_size: 8 + infer_batch_size: 1 + cv_detection: + worker_cls: roll.pipeline.rlvr.rewards.detection_reward_worker.DetectionRewardWorker + model_args: + model_name_or_path: ${pretrain} + # data source whose ability is cv_detection in One-RL-to-See-Them-All/Orsta-Data-47k + tag_included: [v3det_train, object365_train, coco_val_multi_test] + world_size: 8 + infer_batch_size: 1 diff --git a/examples/qwen3-vl-32B-rlvr_megatron/run_rlvr_pipeline.sh b/examples/qwen3-vl-32B-rlvr_megatron/run_rlvr_pipeline.sh new file mode 100644 index 000000000..9a8f0ef42 --- /dev/null +++ b/examples/qwen3-vl-32B-rlvr_megatron/run_rlvr_pipeline.sh @@ -0,0 +1,5 @@ +#!/bin/bash +set +x + +CONFIG_PATH=$(basename $(dirname $0)) +python examples/start_rlvr_vl_pipeline.py --config_path $CONFIG_PATH --config_name rlvr_megatron diff --git a/mcore_adapter/src/mcore_adapter/models/qwen3_vl/modeling_qwen3_vl.py b/mcore_adapter/src/mcore_adapter/models/qwen3_vl/modeling_qwen3_vl.py index 3b6ec224a..fc0494362 100644 --- a/mcore_adapter/src/mcore_adapter/models/qwen3_vl/modeling_qwen3_vl.py +++ b/mcore_adapter/src/mcore_adapter/models/qwen3_vl/modeling_qwen3_vl.py @@ -356,7 +356,7 @@ def forward( } if self.config.context_parallel_size > 1: cp_batch = {k: v.clone() if v is not None else None for k, v in cp_batch.items()} - cp_batch = super().get_batch_on_this_cp_rank(cp_batch, dim3_keys=["attention_mask", "position_ids"]) + cp_batch = super().get_batch_on_this_cp_rank(cp_batch, dim3_keys=[]) if not self.pre_process or (pixel_values is None and pixel_values_videos is None) or decoder_input is not None: return super().forward( From 3657919e05795aa89eca829b7b2ed80d69838c16 Mon Sep 17 00:00:00 2001 From: "fengjingxuan.fjx" Date: Mon, 24 Nov 2025 10:54:49 +0800 Subject: [PATCH 48/58] (feat): add sequence packing for sft pipeline and distill pipeline, optimize memory usage during top-k logits computation. 
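Sequence packing concatenates the valid tokens of several samples into one packed sequence (tracked through cu_seqlens in PackedSeqParams), removing most padding inside a micro batch. It is opt-in per worker via the new use_sequence_packing flag (set use_sequence_packing: True in the worker's config, as in the example YAMLs touched below) and is only supported by the megatron strategies; validate_worker_config currently restricts it to the student, teacher and sft_train workers. Micro-batch settings may need retuning, since a packed micro batch carries more valid tokens than a padded one.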
--- .../distill_megatron.yaml | 12 +- .../qwen2.5-7B-sft_megatron/sft_config.yaml | 1 + roll/configs/base_config.py | 13 + roll/configs/worker_config.py | 8 +- .../distributed/strategy/megatron_strategy.py | 738 +++++++++++++++--- roll/distributed/strategy/strategy.py | 155 ++-- roll/pipeline/agentic/agentic_config.py | 2 + roll/pipeline/distill/distill_config.py | 9 +- roll/pipeline/distill/distill_pipeline.py | 1 + roll/pipeline/distill/distill_worker.py | 103 ++- .../pipeline/distill/logits_transfer_group.py | 3 +- roll/pipeline/distill/various_divergence.py | 194 +---- roll/pipeline/dpo/dpo_config.py | 2 + roll/pipeline/rlvr/rlvr_config.py | 2 + roll/pipeline/sft/sft_config.py | 2 + roll/pipeline/sft/sft_pipeline.py | 1 - roll/pipeline/sft/sft_worker.py | 10 +- roll/utils/sequence_packing.py | 356 +++++++++ 18 files changed, 1232 insertions(+), 380 deletions(-) create mode 100644 roll/utils/sequence_packing.py diff --git a/examples/qwen2.5-7B-distill_megatron/distill_megatron.yaml b/examples/qwen2.5-7B-distill_megatron/distill_megatron.yaml index cbb845a00..8495c9447 100644 --- a/examples/qwen2.5-7B-distill_megatron/distill_megatron.yaml +++ b/examples/qwen2.5-7B-distill_megatron/distill_megatron.yaml @@ -28,7 +28,7 @@ distill_on_prompt: False logits_transfer_backend: "nccl-only" # support "ipc+nccl", "nccl_only" and "ray" -sequence_length: 1024 +sequence_length: 2048 max_grad_norm: 1.0 question_key: question_zh @@ -43,8 +43,8 @@ student: training_args: learning_rate: 2.0e-5 lr_scheduler_type: constant - per_device_train_batch_size: 2 - gradient_accumulation_steps: 1 + per_device_train_batch_size: 8 + gradient_accumulation_steps: 4 warmup_steps: 0 num_train_epochs: 1 @@ -57,10 +57,12 @@ student: strategy_name: megatron_train strategy_config: tensor_model_parallel_size: 2 + sequence_parallel: True pipeline_model_parallel_size: 2 context_parallel_size: 2 use_distributed_optimizer: true recompute_granularity: full + use_sequence_packing: True device_mapping: list(range(0,8)) teacher: @@ -72,14 +74,16 @@ teacher: template: qwen2_5 training_args: # teacher forward micro_batch_size - per_device_train_batch_size: 1 + per_device_train_batch_size: 8 strategy_args: strategy_name: megatron_infer strategy_config: tensor_model_parallel_size: 2 + sequence_parallel: True pipeline_model_parallel_size: 2 context_parallel_size: 2 bf16: true + use_sequence_packing: True device_mapping: list(range(0,8)) system_envs: diff --git a/examples/qwen2.5-7B-sft_megatron/sft_config.yaml b/examples/qwen2.5-7B-sft_megatron/sft_config.yaml index 62b5636ef..a0a721c89 100644 --- a/examples/qwen2.5-7B-sft_megatron/sft_config.yaml +++ b/examples/qwen2.5-7B-sft_megatron/sft_config.yaml @@ -65,5 +65,6 @@ sft_train: pipeline_model_parallel_size: 2 use_distributed_optimizer: true context_parallel_size: 2 + use_sequence_packing: True device_mapping: list(range(0,8)) infer_batch_size: 2 diff --git a/roll/configs/base_config.py b/roll/configs/base_config.py index d30dce574..ca7100a2c 100644 --- a/roll/configs/base_config.py +++ b/roll/configs/base_config.py @@ -289,6 +289,19 @@ def set_max_steps(self, max_steps: int): if hasattr(attribute, "training_args"): setattr(attribute.training_args, "max_steps", max_steps) + def validate_worker_config(self): + # check if current worker supports sequence packing + allowed_names = { + 'student', 'teacher', 'sft_train', + } + for attr_name in dir(self): + attr = getattr(self, attr_name) + if isinstance(attr, WorkerConfig) and attr.use_sequence_packing: + if attr.name not in allowed_names: + 
raise ValueError( + f"Worker '{attr.name}' (from field '{attr_name}') don't support use sequence packing now" + ) + @dataclass class PPOConfig(BaseConfig): # role related diff --git a/roll/configs/worker_config.py b/roll/configs/worker_config.py index 6859ed062..0c7c9ea12 100644 --- a/roll/configs/worker_config.py +++ b/roll/configs/worker_config.py @@ -149,6 +149,13 @@ class WorkerConfig: metadata={"help": "Whether offload nccl buffer to save gpu memory."} ) + # sequence packing + use_sequence_packing: bool = field( + default=False, + metadata={"help": "Concatenates multiple sequences into a single “packed” sequence, eliminating most padding. " + "Only supported in the megatron strategy"} + ) + def __post_init__(self): if self.strategy_args is not None: @@ -186,7 +193,6 @@ def __post_init__(self): elif self.model_args.dtype == "fp16": self.training_args.fp16 = True - def is_colocated(actor_train: WorkerConfig, actor_infer: WorkerConfig): train_devices = set(actor_train.device_mapping or []) infer_devices = set(actor_infer.device_mapping or []) diff --git a/roll/distributed/strategy/megatron_strategy.py b/roll/distributed/strategy/megatron_strategy.py index 55b483fe4..e4038b8f5 100644 --- a/roll/distributed/strategy/megatron_strategy.py +++ b/roll/distributed/strategy/megatron_strategy.py @@ -19,13 +19,14 @@ from megatron.core.models.common.embeddings import RotaryEmbedding from megatron.core.optimizer import MegatronOptimizer, OptimizerConfig from megatron.core.pipeline_parallel import get_forward_backward_func -from megatron.core.tensor_parallel import gather_from_tensor_model_parallel_region +from megatron.core.tensor_parallel import gather_from_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region from megatron.core.transformer.moe.moe_utils import ( clear_aux_losses_tracker, get_moe_layer_wise_logging_tracker, reduce_aux_losses_tracker_across_ranks, ) from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy +from megatron.core.packed_seq_params import PackedSeqParams from mcore_adapter import TrainingArguments from mcore_adapter.checkpointing import get_checkpoint_dir, load_state_dict_from_checkpoint @@ -52,8 +53,8 @@ from roll.utils.logging import get_logger from roll.utils.offload_states import OffloadStateType from roll.utils.dynamic_batching import make_micro_batch_iter_for_dynamic_batching -from roll.platforms import current_platform +from roll.platforms import current_platform logger = get_logger() @@ -74,6 +75,8 @@ def __init__(self, worker: Worker): self.forward_backward_func = None self.seq_length = None self.use_remove_padding = self.worker_config.use_remove_padding + self.use_sequence_packing = self.worker_config.use_sequence_packing + self.max_packed_len = None # hard to impl with offload states assert not self.megatron_train_args.overlap_param_gather, "overlap_param_gather is not supported" if self.worker_config.use_remove_padding: @@ -143,15 +146,19 @@ def forward_step( self.model.eval() output_on_all_tp_cp_ranks = batch.meta_info.get("output_on_all_tp_cp_ranks", False) if self.worker_config.use_dynamic_batching_in_infer: - data_iterator = make_micro_batch_iter_for_dynamic_batching(batch) + micro_batches_list = list(make_micro_batch_iter_for_dynamic_batching(batch)) num_microbatches = batch.meta_info["num_micro_batchs"] micro_batch_size = 1 else: batch_size = batch.batch.batch_size[0] micro_batch_size = batch.meta_info["micro_batch_size"] num_microbatches = max(batch_size // micro_batch_size, 1) - micro_batches = 
batch.chunk(chunks=num_microbatches) - data_iterator = [iter(micro_batches) for _ in range(len(self.model))] + micro_batches_list = batch.chunk(chunks=num_microbatches) + if self.use_sequence_packing: + micro_batch_size = 1 + self.max_packed_len = self._get_max_packed_len(micro_batches_list) + + data_iterator = [iter(micro_batches_list) for _ in range(len(self.model))] with disable_gradients(models=self.model.get_models()): # List 是每个 micro-batch 构成的 losses_reduced: List[Dict[str, torch.Tensor]] = self.forward_backward_func( @@ -159,7 +166,7 @@ def forward_step( data_iterator=data_iterator, model=self.model.get_models(), num_microbatches=num_microbatches, - seq_length=self.seq_length, + seq_length=self.seq_length if not self.use_sequence_packing else self.max_packed_len, micro_batch_size=micro_batch_size, forward_only=True, ) @@ -192,20 +199,253 @@ def _get_unpad_seqlen(self, attention_mask: torch.Tensor, pad_to_multiple_of: in return padded_max_seqlen + def _get_pad_factor(self): + # caculate pad_factor in sequence packing + cp_size = mpu.get_context_parallel_world_size() + tp_size = mpu.get_tensor_model_parallel_world_size() + pad_factor = cp_size * 2 * tp_size if cp_size > 1 else tp_size + pad_factor = math.lcm(16, pad_factor) + return pad_factor + + def _get_max_packed_len(self, micro_batches_list): + max_packed_len = -1 + for micro_batch in micro_batches_list: + input_ids = micro_batch.batch["input_ids"] + attention_mask = micro_batch.batch["attention_mask"] + + batch_size = input_ids.shape[0] + seq_lens = attention_mask.sum(dim=-1) + + pad_factor = self._get_pad_factor() + + packed_len = 0 + for b in range(batch_size): + seq_len = seq_lens[b].item() if torch.is_tensor(seq_lens[b]) else seq_lens[b] + if pad_factor > 1: + padded_seq_len = ((seq_len + pad_factor - 1) // pad_factor) * pad_factor + else: + padded_seq_len = seq_len + packed_len += padded_seq_len + + max_packed_len = max(packed_len, max_packed_len) + return max_packed_len + + def _pack_sequences(self, input_tensor, attention_mask, pad_packed_seq_to=None, pad_val=0): + """ + Pack multiple sequences into a single continuous sequence by removing padding. + + Implements sequence packing for efficient batch processing with variable-length sequences. + Removes per-sample padding and concatenates sequences while maintaining cumulative length info. + + Reference: https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/megatron/common.py + + Args: + input_tensor (torch.Tensor): Shape [batch_size, seq_len, ...], padded sequences. + attention_mask (torch.Tensor): Shape [batch_size, seq_len], 1=valid, 0=padding. + pad_packed_seq_to (int, optional): Target length for packed sequence. Defaults to None. + pad_val (int): Padding value. Defaults to 0. 
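# A minimal standalone sketch of the padding-alignment rule used by _get_pad_factor /
# _get_max_packed_len above, assuming plain PyTorch tensors; `pad_factor` and
# `packed_len` are illustrative helpers, not functions from this codebase.
import math
import torch

def pad_factor(tp_size: int, cp_size: int) -> int:
    base = cp_size * 2 * tp_size if cp_size > 1 else tp_size
    return math.lcm(16, base)

def packed_len(attention_mask: torch.Tensor, factor: int) -> int:
    # attention_mask: [batch, seq_len], 1 = valid token, 0 = padding
    seq_lens = attention_mask.sum(dim=-1).tolist()
    return sum(((n + factor - 1) // factor) * factor for n in seq_lens)

mask = torch.tensor([[1] * 5 + [0] * 3,
                     [1] * 7 + [0] * 1])      # valid lengths 5 and 7
f = pad_factor(tp_size=2, cp_size=2)          # lcm(16, 2 * 2 * 2) = 16
print(f, packed_len(mask, f))                 # 16 32  (5 -> 16, 7 -> 16)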
+ + Returns: + tuple: (packed_input_tensor, packed_seq_params, cu_seqlens, cu_seqlens_padded) + - packed_input_tensor: Shape [1, total_packed_length, ...], ready for current CP rank + - packed_seq_params: PackedSeqParams with cumulative lengths and max_seqlen + - cu_seqlens: Shape [batch_size + 1], cumulative lengths of original sequences + - cu_seqlens_padded: Shape [batch_size + 1], cumulative lengths after alignment + + Note: + - Sequences padded to alignment boundaries if pad_factor > 1 or pad_packed_seq_to is set + - For CP training, sequences distributed across CP ranks + - attention_mask not needed after packing + """ + + batch_size = input_tensor.shape[0] + seq_lens = attention_mask.sum(dim=-1) + pad_factor = self._get_pad_factor() + + # Remove padding from each sequence + # Note: attention_mask is not needed in sequence packing mode + input_tensor_unpadded = [input_tensor[b][:seq_lens[b]] for b in range(batch_size)] + + # Build cumulative sequence lengths + cu_seqlens = [0] + cu_seqlens_padded = ([0] if pad_factor > 1 or pad_packed_seq_to is not None + else None + ) + + # Calculate cumulative lengths for both original and padded sequences + for b in range(batch_size): + seq_len = seq_lens[b].item() if torch.is_tensor(seq_lens[b]) else seq_lens[b] + cu_seqlens.append(cu_seqlens[-1] + seq_len) + if pad_factor > 1 or pad_packed_seq_to is not None: + # Pad sequence length to multiple of pad_factor + padded_seq_len = ((seq_len + pad_factor - 1) // pad_factor) * pad_factor + cu_seqlens_padded.append(cu_seqlens_padded[-1] + padded_seq_len) + + # Convert to tensors + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=current_platform.device_type) + if pad_factor > 1 or pad_packed_seq_to is not None: + cu_seqlens_padded = torch.tensor(cu_seqlens_padded, dtype=torch.int32, device=current_platform.device_type) + if pad_packed_seq_to is not None: + cu_seqlens_padded[-1] = pad_packed_seq_to + + # Calculate maximum sequence length + if pad_factor > 1 or pad_packed_seq_to is not None: + seq_lens_padded = cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] + max_seqlen = seq_lens_padded.max().item() + else: + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = seq_lens.max().item() + + cp_size = mpu.get_context_parallel_world_size() + + # Track running sequence length for padding + running_seq_len = 0 + if pad_factor > 1: + all_input_tensor_padded = [] + padded_tokens = [] + for b in range(batch_size): + seq_len = seq_lens[b].item() if torch.is_tensor(seq_lens[b]) else seq_lens[b] + if b == batch_size - 1 and pad_packed_seq_to is not None: + # Different from original implementation: calculate remaining length + padded_seq_len = pad_packed_seq_to - running_seq_len + else: + # Align to pad_factor boundary + padded_seq_len = ((seq_len + pad_factor - 1) // pad_factor) * pad_factor + + running_seq_len += padded_seq_len + + seq_tokens = input_tensor_unpadded[b] + + # Pad sequence if needed + if padded_seq_len > seq_len: + seq_tokens = torch.nn.functional.pad( + seq_tokens, (0, padded_seq_len - seq_len), value=pad_val + ) + all_input_tensor_padded.append(seq_tokens) + + if cp_size > 1: + # Handle Context Parallel distribution + # Add batch dimension for processing + seq_tokens_with_batch = seq_tokens.unsqueeze(0) # [1, seq_len] + seq_tokens_with_batch = self._get_feature_on_this_cp_rank( + seq_tokens_with_batch, "seq_tokens" + ) + seq_tokens = seq_tokens_with_batch.squeeze(0) # Remove batch dimension + + padded_tokens.append(seq_tokens) + + # Concatenate all sequences + 
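# A toy version of the packing step above, assuming a plain [batch, seq_len] batch and
# no TP/CP groups; `toy_pack` is an illustrative helper, not the worker implementation.
# It removes per-sample padding, aligns each sample to `factor`, and records both the
# unpadded (cu_seqlens) and aligned (cu_seqlens_padded) cumulative lengths.
import torch

def toy_pack(input_ids: torch.Tensor, attention_mask: torch.Tensor, factor: int):
    seq_lens = attention_mask.sum(dim=-1).tolist()
    cu, cu_pad, pieces = [0], [0], []
    for ids, n in zip(input_ids, seq_lens):
        n_pad = ((n + factor - 1) // factor) * factor
        cu.append(cu[-1] + n)
        cu_pad.append(cu_pad[-1] + n_pad)
        pieces.append(torch.nn.functional.pad(ids[:n], (0, n_pad - n), value=0))
    packed = torch.cat(pieces).unsqueeze(0)    # [1, total_packed_len], "thd"-style layout
    return packed, torch.tensor(cu), torch.tensor(cu_pad)

ids = torch.arange(1, 17).reshape(2, 8)        # two samples of length 8
mask = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0],
                     [1, 1, 1, 1, 1, 1, 0, 0]])
packed, cu, cu_pad = toy_pack(ids, mask, factor=4)
print(packed.shape, cu.tolist(), cu_pad.tolist())   # torch.Size([1, 12]) [0, 3, 9] [0, 4, 12]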
packed_input_tensor = torch.cat(padded_tokens, dim=0).unsqueeze(0) + all_input_tensor_padded = torch.cat(all_input_tensor_padded, dim=0).unsqueeze(0) + + else: + # No padding factor: simply concatenate unpadded sequences + packed_input_tensor = torch.cat(input_tensor_unpadded, dim=0).unsqueeze(0) + all_input_tensor_padded = packed_input_tensor + if pad_packed_seq_to is not None: + # Pad to target length if specified + pad_len = pad_packed_seq_to - packed_input_tensor.shape[1] + if pad_len > 0: + packed_input_tensor = torch.nn.functional.pad( + packed_input_tensor, (0, pad_len), value=pad_val + ) + all_input_tensor_padded = torch.nn.functional.pad( + all_input_tensor_padded, (0, pad_len), value=pad_val + ) + + if cu_seqlens_padded is None: + cu_seqlens_padded = cu_seqlens.clone() + + # Create packed sequence parameters for attention computation + # Only use padded cumulative sequence lengths + packed_seq_params = PackedSeqParams( + cu_seqlens_q=cu_seqlens_padded, + cu_seqlens_kv=cu_seqlens_padded, + cu_seqlens_q_padded=cu_seqlens_padded, + cu_seqlens_kv_padded=cu_seqlens_padded, + # Individual sequence length + max_seqlen_q=int(max_seqlen), + max_seqlen_kv=int(max_seqlen), + qkv_format="thd", + ) + + return ( + # Packed input tensor for current rank (especially CP rank) computation + # Contains all tokens from the batch with individual sample padding/alignment preserved + packed_input_tensor.contiguous(), + + # Parameters required for sequence packing + packed_seq_params, + + # Cumulative sequence lengths of original unpadded data + cu_seqlens, + + # Cumulative sequence lengths after padding/alignment + cu_seqlens_padded, + ) + + def _get_tokens_on_this_cp_rank( + self, + input_ids: torch.Tensor, + cp_rank: int, + cp_size: int, + seq_dim: int = 1, + ) -> torch.Tensor: + """Get tokens on this context parallelism rank. + + Assumes that input_ids are already padded to a multiple of cp_size * 2 or cp_size == 1. 
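# A standalone sketch of the load-balanced context-parallel split described above: each
# CP rank keeps chunk r and chunk (2 * cp_size - 1 - r) so causal-attention work is
# roughly even across ranks. `cp_shard` is illustrative, not the Megatron helper itself.
import torch

def cp_shard(tokens: torch.Tensor, cp_rank: int, cp_size: int, seq_dim: int = 1):
    if cp_size == 1:
        return tokens
    chunks = tokens.chunk(cp_size * 2, dim=seq_dim)
    picks = (cp_rank, 2 * cp_size - 1 - cp_rank)
    return torch.cat([chunks[i] for i in picks], dim=seq_dim)

seq = torch.arange(16).unsqueeze(0)            # [1, 16]; cp_size=2 -> 4 chunks of 4
print(cp_shard(seq, cp_rank=0, cp_size=2))     # chunks 0 and 3: tokens 0-3 and 12-15
print(cp_shard(seq, cp_rank=1, cp_size=2))     # chunks 1 and 2: tokens 4-11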
+ + Args: + input_ids: Input token IDs [seq_length, ] + cp_rank: Context parallelism rank + cp_size: Context parallelism size + + Returns: + Tokens on this context parallelism rank [1, seq_length // cp_size] + """ + if cp_size == 1: + return input_ids + + # load balance for causal attention + shard_size = input_ids.shape[seq_dim] // (cp_size * 2) + shard_inds = (cp_rank, (cp_size * 2) - cp_rank - 1) + + # Create slices for each dimension + slices = [slice(None)] * input_ids.dim() + ids_chunks = [] + + for ind in shard_inds: + slices[seq_dim] = slice(ind * shard_size, (ind + 1) * shard_size) + ids_chunks.append(input_ids[slices]) + + ids = torch.cat(ids_chunks, dim=seq_dim) + return ids + def inner_forward_step(self, loss_func, data_iterator: Iterator[DataProto], model): data = next(data_iterator) input_ids = data.batch["input_ids"] attention_mask = data.batch["attention_mask"] + labels = data.batch["labels"] if "labels" in data.batch else None # labels is only used for sft + packed_seq_params = None + if self.use_remove_padding: unpad_seq_len = self._get_unpad_seqlen(attention_mask=attention_mask) input_ids = input_ids[:, :unpad_seq_len].contiguous() attention_mask = attention_mask[:, :unpad_seq_len].contiguous() - - input_ids = self._get_feature_on_this_cp_rank(input_ids, "input_ids") - attention_mask = self._get_feature_on_this_cp_rank(attention_mask, "attention_mask") - labels = data.batch["labels"] if "labels" in data.batch else None # labels is only used for sft - if labels is not None: - labels = self._get_feature_on_this_cp_rank(labels, "labels") + if self.use_sequence_packing: + input_ids, packed_seq_params, cu_seqlens, cu_seqlens_padded = self._pack_sequences( + input_ids, attention_mask, pad_packed_seq_to=self.max_packed_len + ) + if labels is not None: + labels, _, _, _ = self._pack_sequences(labels, attention_mask, pad_packed_seq_to=self.max_packed_len, + pad_val=IGNORE_INDEX) + data.meta_info['labels_packed'] = labels + attention_mask = None + else: + input_ids = self._get_feature_on_this_cp_rank(input_ids, "input_ids") + attention_mask = self._get_feature_on_this_cp_rank(attention_mask, "attention_mask") + if labels is not None: + labels = self._get_feature_on_this_cp_rank(labels, "labels") position_ids = None # attention_mask: SelfAttention defalt to te DotProductAttention with # AttnMaskType.causal in which attention_mask would not be used, pass @@ -237,9 +477,13 @@ def inner_forward_step(self, loss_func, data_iterator: Iterator[DataProto], mode forward_args.update({"force_vit_image": True}) output_tensor = model( - input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels, **forward_args + input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels, + packed_seq_params=packed_seq_params, **forward_args ) + if self.use_sequence_packing: + loss_func.set_packing_params(cu_seqlens=cu_seqlens, cu_seqlens_padded=cu_seqlens_padded, logger=logger) + return output_tensor, partial(loss_func, data) def broadcast_parameter(self, model_update_name, src_pp_rank, dtype, shape, parameter_name): @@ -296,148 +540,396 @@ def op_compute_entropy(self, logits: torch.Tensor, attention_mask: torch.Tensor) entropy = entropy[:, :-1] * attention_mask[:, 1:] return entropy - def op_compute_language_loss_from_logits(self, logits: torch.Tensor, targets: torch.Tensor): + def op_compute_language_loss_from_logits( + self, + logits: torch.Tensor, + targets: torch.Tensor, + reduction: str = "mean" + ): """ - Compute TP + CP aware GPT cross 
entropy loss. + Compute cross-entropy language modeling loss with TP and CP support. + + Handles causal next-token prediction with proper sequence boundary alignment + in distributed training scenarios. Args: - logits: [batch_size, local_seq_len, vocab_size/TP] - If CP is not enabled, local_seq_len = global_seq_len. - targets: [batch_size, global_seq_len] - Global vocab ids for target tokens. + logits (torch.Tensor): Shape [batch_size, local_seq_len, vocab_size/tp_size]. + TP-sharded (vocab) and CP-sharded (sequence). + targets (torch.Tensor): Shape [batch_size, global_seq_len]. + Global vocab IDs, padding marked with IGNORE_INDEX. + reduction (str): "mean" or "sum". Default: "mean". + Returns: - loss: scalar tensor, global average per-token loss + tuple: (loss, token_count) + - loss: Scalar tensor based on reduction method + - token_count: int64 tensor, number of valid tokens + + Sequence Alignment: + - No CP: Simple shift, logits[:, :-1] predicts targets[:, 1:] + - With CP (2 chunks/rank): Handle chunk boundaries carefully + * Chunk 0: logits[:, :chunk_size-1] → targets[:, 1:chunk_size] + * Chunk 1: logits[:, chunk_size:-1] → targets[:, chunk_size+1:] + + Note: + - vocab_parallel_cross_entropy handles TP all-reduce internally + - CP all-reduce performed explicitly for loss_sum and token_count + - Assumes 2 chunks per rank in CP mode for load balancing """ + cp_size = mpu.get_context_parallel_world_size() + + # Slice targets to current CP rank's sequence portion targets = self._get_feature_on_this_cp_rank(targets, "targets") - # shift - logits = logits[..., :-1, :].contiguous() - targets = targets[..., 1:].contiguous() - # TODO: support use remove padding - # [L_local, B, V/TP]: Megatron CE expects sequence-first layout + + if cp_size == 1: + # Simple causal shift: logits[t] predicts targets[t+1] + logits = logits[:, :-1, :].contiguous() + targets = targets[:, 1:].contiguous() + else: + # CP mode: Handle chunk boundaries with load balancing + local_seq_len = logits.size(1) + chunk_size = local_seq_len // 2 # 2 chunks per rank + + # Chunk 0: Remove last position (its target is in Chunk 1) + chunk_0_logits = logits[:, :chunk_size - 1, :] + chunk_0_targets = targets[:, 1:chunk_size] + + # Chunk 1: Remove last position and skip first target (belongs to Chunk 0) + chunk_1_logits = logits[:, chunk_size:-1, :] + chunk_1_targets = targets[:, chunk_size + 1:] + + # Merge chunks + logits = torch.cat([chunk_0_logits, chunk_1_logits], dim=1) + targets = torch.cat([chunk_0_targets, chunk_1_targets], dim=1) + + # Transpose to sequence-first layout for Megatron CE logits_tp = logits.transpose(0, 1).contiguous() labels_tp = targets.transpose(0, 1).contiguous() - # (1) Compute per-token CE loss on the local TP shard - # This function handles TP all-reduce internally to compute - # the global denominator (sum exp logits) and target logits. 
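# A shapes-only sketch of the CP-aware causal shift described above (2 chunks per rank):
# the last logit of each local chunk and the first target of chunk 1 pair with tokens or
# logits held on other ranks, so both are dropped locally. `cp_shift` is an illustrative
# helper, assuming logits/targets are already sliced to this CP rank.
import torch

def cp_shift(logits: torch.Tensor, targets: torch.Tensor, cp_size: int):
    if cp_size == 1:
        return logits[:, :-1], targets[:, 1:]
    half = logits.size(1) // 2
    shifted_logits = torch.cat([logits[:, :half - 1], logits[:, half:-1]], dim=1)
    shifted_targets = torch.cat([targets[:, 1:half], targets[:, half + 1:]], dim=1)
    return shifted_logits, shifted_targets

logits = torch.randn(2, 8, 11)                 # local_seq = 8 (two chunks of 4)
targets = torch.randint(0, 11, (2, 8))
l, t = cp_shift(logits, targets, cp_size=2)
print(l.shape, t.shape)                        # torch.Size([2, 6, 11]) torch.Size([2, 6])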
- # Output shape: [L_local, B] + # Compute per-token CE loss (handles TP all-reduce) loss_per_token = vocab_parallel_cross_entropy( logits_tp, labels_tp, label_smoothing=0.0 ) - # (2) Apply ignore_index mask (set loss to 0 for ignored positions) + # Apply ignore_index mask mask = (labels_tp != IGNORE_INDEX) loss_sum_local = (loss_per_token * mask).sum() token_count_local = mask.sum() - # (3) If Context Parallel is enabled, aggregate loss and token count - # across the CP group to get global values - if mpu.get_context_parallel_world_size() > 1: + # All-reduce across CP ranks + if cp_size > 1: cp_group = mpu.get_context_parallel_group() - # Stack loss sum and token count to reduce in a single communication - stats_tensor = torch.stack([loss_sum_local, token_count_local], dim=0) + stats_tensor = torch.stack([ + loss_sum_local.float(), + token_count_local.float() + ], dim=0) dist.all_reduce(stats_tensor, op=dist.ReduceOp.SUM, group=cp_group) loss_sum, token_count = stats_tensor[0], stats_tensor[1] else: - loss_sum, token_count = loss_sum_local, token_count_local + loss_sum = loss_sum_local.float() + token_count = token_count_local.float() + + # Apply reduction + if reduction == "sum": + loss = loss_sum + elif reduction == "mean": + loss = loss_sum / torch.clamp(token_count, min=1.0) + else: + raise ValueError(f"Unsupported reduction: {reduction}. Use 'mean' or 'sum'.") + + return loss, token_count.to(torch.int64) + + def op_compute_topk_logits( + self, + logits: torch.Tensor, + topk: int = 0 + ): + """ + Compute top-k logits with memory-efficient two-stage approach for TP and CP training. + + Strategy: + - topk=0: Gather full vocab across TP ranks + - topk>0: Two-stage TopK (local → gather K values → global TopK → CP gather) - # (4) Compute global average per-token loss - loss = loss_sum / token_count - return loss + Args: + logits (torch.Tensor): Shape [batch_size, local_seq_len, local_vocab_size]. + TP-sharded along vocabulary. + topk (int): 0=full vocab, >0=top-k mode. - def op_compute_logits(self, logits: torch.Tensor, tp_gather: bool = False, cp_gather: bool = False, topk: int = 0): + Returns: + tuple: (values, indices) + - topk=0: (logits [B, S, V], None) + - topk>0: (values [B, S, K], indices [B, S, K] in global vocab space) + + Note: + - Indices adjusted to global vocabulary space + - Intermediate tensors deleted early + - CP gathering after TP operations """ - Post-process logits. - - If topk == 0 (full-vocab mode), optionally gather across TP/CP ranks - using tp_gather and cp_gather flags. - If topk > 0, return top-K values and indices for each position. - - Args: - logits: [B, local_seq_len, local_vocab_size] tensor. - tp_gather: Gather full vocab across tensor-parallel ranks (only if topk==0). - cp_gather: Gather full sequence across context-parallel ranks (only if topk==0). - topk: 0 for full vocab, >0 for top-K mode. 
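# A standalone sketch of the two-stage top-k described above, with the TP group
# simulated by a Python list of vocabulary shards (no distributed setup assumed):
# local top-k per shard, shift indices into the global vocab, concatenate the
# candidates, then take the final global top-k. This is exact, because every global
# top-k element is necessarily in the top-k of its own shard.
import torch

def two_stage_topk(shards, k):
    cand_vals, cand_idx, offset = [], [], 0
    for shard in shards:                       # one iteration per simulated TP rank
        vals, idx = torch.topk(shard, k=k, dim=-1, sorted=False)
        cand_vals.append(vals)
        cand_idx.append(idx + offset)          # shift into global vocab space
        offset += shard.shape[-1]
    vals = torch.cat(cand_vals, dim=-1)        # [B, S, k * num_shards]
    idx = torch.cat(cand_idx, dim=-1)
    top_vals, pos = torch.topk(vals, k=k, dim=-1, sorted=True)
    return top_vals, torch.gather(idx, -1, pos)

full = torch.randn(2, 3, 8)
vals, idx = two_stage_topk(list(full.chunk(2, dim=-1)), k=2)
ref_vals, ref_idx = torch.topk(full, k=2, dim=-1)
print(torch.allclose(vals, ref_vals),
      torch.equal(idx.sort(-1).values, ref_idx.sort(-1).values))   # True True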
- - Returns: - (values, indices): - - full-vocab: (logits, dummy indices) - - top-K: (topk_values, topk_indices)""" - # TODO: support use remove padding - # TP gather vocab - full_logits = logits - if tp_gather: - full_logits = gather_from_tensor_model_parallel_region(full_logits) - # CP gather seq - if mpu.get_context_parallel_world_size() > 1 and cp_gather: - full_logits = context_parallel_gather(full_logits, parallel_dim=1) + + tp_size = mpu.get_tensor_model_parallel_world_size() + cp_size = mpu.get_context_parallel_world_size() + + # ========== TopK Mode: Two-Stage Memory Optimization ========== + if topk > 0: + # Stage 1: Local TopK on each TP rank's vocabulary shard + # Memory reduction: [B, local_seq, local_vocab] -> [B, local_seq, K] + local_topk_values, local_topk_indices = torch.topk( + logits, k=topk, dim=-1, sorted=False + ) + + # Adjust indices to global vocabulary space + # Each TP rank owns a contiguous vocabulary range [vocab_start, vocab_end) + vocab_start_index = mpu.get_tensor_model_parallel_rank() * logits.shape[-1] + local_topk_indices = local_topk_indices + vocab_start_index + + # Release original logits immediately to save memory + del logits + + # Stage 2: Gather local TopK results across TP ranks + # Memory: [B, local_seq, K] -> [B, local_seq, K * tp_world_size] + # Only gather K values per rank instead of full vocabulary + gathered_values = local_topk_values + gathered_indices = local_topk_indices + if tp_size > 1: + gathered_values = gather_from_tensor_model_parallel_region(local_topk_values) + gathered_indices = gather_from_tensor_model_parallel_region(local_topk_indices) + del local_topk_values, local_topk_indices + + # Stage 3: Global TopK on gathered candidates + # Select final top-k from K * tp_size candidates + # Memory: [B, local_seq, K * tp_world_size] -> [B, local_seq, K] + final_topk_values, topk_positions = torch.topk( + gathered_values, k=topk, dim=-1, sorted=True + ) + # Use topk_positions to gather corresponding global indices + final_topk_indices = torch.gather( + gathered_indices, dim=-1, index=topk_positions + ) + del gathered_values, gathered_indices, topk_positions + + # Stage 4: CP gather for sequence parallel training + if cp_size > 1: + final_topk_values = context_parallel_gather(final_topk_values, parallel_dim=1) + final_topk_indices = context_parallel_gather(final_topk_indices, parallel_dim=1) + + return final_topk_values, final_topk_indices + + # ========== Full Vocabulary Mode: Traditional Gather Path ========== + result = logits + # Gather full vocabulary across TP ranks + if tp_size > 1: + result = gather_from_tensor_model_parallel_region(result) + + # Gather across CP ranks for sequence parallelism + if cp_size > 1: + result = context_parallel_gather(result, parallel_dim=1) + + # Return full vocabulary logits if topk == 0: - batch_size = full_logits.shape[0] - # return dummy topk indices for transfer - return full_logits, torch.empty([batch_size, 1], device=logits.device) - else: - return torch.topk(full_logits, k=topk, dim=-1) + return result, None - def op_compute_prepare_cp_local_iterator(self, tensor: torch.Tensor, feature_name: str, micro_batch_size: int): + # Fallback: TopK mode without TP optimization (when TP is not used) + topk_values, topk_indices = torch.topk(result, k=topk, dim=-1) + del result + + return topk_values, topk_indices + + def op_compute_gather_by_teacher_indices( + self, + student_logits: torch.Tensor, + teacher_indices: torch.Tensor + ): """ - Prepare a microbatch iterator for a tensor that may require 
Context Parallel (CP) slicing. + Gather student logits at teacher indices with TP support via sparse gather. - Notes: - - If the input `tensor` is None, this function returns None. - - The input tensor is assumed to have shape [global_batch, global_seq_len, ...], - with batch dimension first and sequence dimension second. - - When CP size > 1, the sequence dimension (dim=1) is sliced to the local CP rank - using `_get_feature_on_this_cp_rank`. - - After CP slicing (if any), the tensor is split along the batch dimension (dim=0) - into microbatches of size `micro_batch_size`. + Strategy: + - No TP: Direct torch.gather + - TP mode: Sparse gather + all-reduce + 1. Mask indices belonging to local vocab shard + 2. Gather local values, zero out non-local + 3. All-reduce sum across TP ranks Args: - tensor (torch.Tensor or None): Full tensor before CP slicing. Can be None. - feature_name (str): Identifier passed to `_get_feature_on_this_cp_rank` for CP slicing. - e.g., "teacher_logits" or "teacher_topk_indices". - micro_batch_size (int): Number of samples per microbatch; splitting is done along dim=0. + student_logits (torch.Tensor): Shape [batch_size, seq_len, local_vocab_size]. + TP-sharded along vocabulary. + teacher_indices (torch.Tensor): Shape [batch_size, seq_len, k] or [batch_size, seq_len]. + Global vocabulary indices (not sharded). Returns: - iterator or None: Iterator over microbatches with shape - [micro_batch_size, local_seq_len, ...]. - Returns None if input `tensor` is None. + torch.Tensor: Gathered logits matching teacher_indices shape. + + Note: + - Returns original logits if teacher_indices is None + - Handles 2D/3D indices, restores original shape + - Vocab range per rank: [tp_rank * local_vocab_size, (tp_rank+1) * local_vocab_size) """ - if tensor is None: - return None - # CP slicing on sequence dimension if enabled - if mpu.get_context_parallel_world_size() > 1: - local_tensor = self._get_feature_on_this_cp_rank(tensor, feature_name) - else: - local_tensor = tensor + # Early return if no teacher indices provided + if teacher_indices is None: + return student_logits + + # Ensure indices are long type for indexing + if teacher_indices.dtype != torch.long: + teacher_indices = teacher_indices.long() - # Microbatch split along batch dimension - return iter(local_tensor.split(micro_batch_size, dim=0)) + # Handle 2D input by adding dimension (will be removed before return) + squeeze_output = False + if teacher_indices.dim() == 2: + teacher_indices = teacher_indices.unsqueeze(-1) + squeeze_output = True - def op_compute_various_divergence(self, loss_callable, logits, teacher_logits, teacher_topk_indices, - labels, attention_mask=None): + tp_world_size = mpu.get_tensor_model_parallel_world_size() + + # Non-TP mode: Direct gather operation + if tp_world_size == 1: + gathered = torch.gather(student_logits, dim=-1, index=teacher_indices) + return gathered.squeeze(-1) if squeeze_output else gathered + + # ========== TP-Sharded Sparse Gather ========== + tp_rank = mpu.get_tensor_model_parallel_rank() + local_vocab_size = student_logits.shape[-1] + + # Calculate vocabulary range owned by current TP rank + vocab_start = tp_rank * local_vocab_size + vocab_end = vocab_start + local_vocab_size + + # Create mask for indices that belong to local vocabulary shard + local_mask = (teacher_indices >= vocab_start) & (teacher_indices < vocab_end) + + # Convert global indices to local vocabulary space + # Clamp to valid range to avoid index errors (non-local indices will be masked out) + local_indices = 
teacher_indices - vocab_start + local_indices = torch.clamp(local_indices, 0, local_vocab_size - 1) + + # Gather values from local vocabulary shard + local_gathered = torch.gather(student_logits, dim=-1, index=local_indices) + + # Mask out values that don't belong to local vocabulary + # Non-local positions are set to zero (will not contribute to final sum) + local_gathered = torch.where(local_mask, local_gathered, torch.zeros_like(local_gathered)) + + # All-reduce sum across TP ranks (fully differentiable) + # Forward: Sum contributions from all ranks (only one rank contributes non-zero per index) + # Backward: Each rank receives full gradient, but only masked portion affects local parameters + gathered = reduce_from_tensor_model_parallel_region(local_gathered) + + # Restore original shape if input was 2D + return gathered.squeeze(-1) if squeeze_output else gathered + + def op_compute_various_divergence( + self, + loss_callable, logits, teacher_topk_probs, teacher_topk_log_probs, teacher_topk_indices, + teacher_topk_inf_mask, labels, attention_mask=None, reduction="mean" + ): """ - Note: - `logits` here are both TP (Tensor Parallel) and CP (Context Parallel) sharded. - `logits` here are CP (Context Parallel) sharded. - - We gather across TP to get full-vocab logits for the local CP sequence slice. - `labels`, and `attention_mask` are provided as full tensors - (global sequence length). These are then sliced down to the local CP rank's - sequence shard before loss computation. - """ - # TODO: support TP and remove padding - full_logits = gather_from_tensor_model_parallel_region(logits) - full_logits = self.op_compute_gather_by_teacher_indices(full_logits, teacher_topk_indices) - if teacher_logits.shape[-1] != full_logits.shape[-1]: - teacher_logits = teacher_logits[:, :, : min(full_logits.shape[-1], teacher_logits.shape[-1])] - labels = self._get_feature_on_this_cp_rank(labels, "labels") - if attention_mask is not None: - attention_mask = self._get_feature_on_this_cp_rank(attention_mask, "attention_mask") - loss = loss_callable(logits=full_logits, teacher_logits=teacher_logits, labels=labels, attention_mask=attention_mask) - return loss + Compute divergence losses (KL, JSD, RKL, etc.) with TP and CP support. + + Strategy: + 1. Slice teacher outputs to current CP rank's sequence + 2. Gather student logits at teacher's top-k indices (TP-aware) + 3. Compute per-token divergence loss + 4. Gather loss across CP ranks + 5. Apply padding mask and reduction + + Args: + loss_callable (callable): Divergence function (KL/JSD/RKL). + Takes: logits, teacher_probs, teacher_log_probs, teacher_inf_mask. + logits (torch.Tensor): Shape [batch_size, local_seq_len, local_vocab_size]. + TP and CP sharded. + teacher_topk_probs (torch.Tensor): Shape [batch_size, global_seq_len, topk]. + Full tensor (not sharded). + teacher_topk_log_probs (torch.Tensor): Shape [batch_size, global_seq_len, topk]. + teacher_topk_indices (torch.Tensor): Shape [batch_size, global_seq_len, topk]. + Global vocabulary indices. + teacher_topk_inf_mask (torch.Tensor): Shape [batch_size, global_seq_len, topk]. + labels (torch.Tensor): Shape [batch_size, global_seq_len]. + Padding marked with IGNORE_INDEX. + attention_mask (torch.Tensor, optional): Shape [batch_size, global_seq_len]. + 0=padding. Used if labels is None. + reduction (str): "mean", "sum", or "none". 
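# A standalone sketch of the sparse gather described above, with the TP all-reduce
# emulated by summing over a list of vocabulary shards: each shard contributes only the
# teacher indices it owns and zeros elsewhere, so the summed result matches a plain
# torch.gather on the full vocabulary. `sharded_gather` is illustrative only.
import torch

def sharded_gather(shards, global_idx):
    out, start = 0, 0
    for shard in shards:                       # one iteration per simulated TP rank
        end = start + shard.shape[-1]
        owned = (global_idx >= start) & (global_idx < end)
        local_idx = (global_idx - start).clamp(0, shard.shape[-1] - 1)
        vals = torch.gather(shard, -1, local_idx)
        out = out + torch.where(owned, vals, torch.zeros_like(vals))   # stands in for the all-reduce sum
        start = end
    return out

full = torch.randn(2, 4, 10)
idx = torch.randint(0, 10, (2, 4, 3))
print(torch.allclose(sharded_gather(list(full.chunk(2, dim=-1)), idx),
                     torch.gather(full, -1, idx)))                     # True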
+ + Returns: + tuple: (loss, token_count) + - loss: Scalar (mean/sum) or tensor [B, S] (none) + - token_count: Scalar, number of valid tokens + + Note: + - Teacher outputs sliced to CP rank's sequence + - Student logits TP-sharded, handled by sparse gather + - Token count from full sequence for correct normalization + """ + + # Preserve full tensors for final mask computation + labels_full = labels + attention_mask_full = attention_mask + + # (1) Slice teacher outputs to current CP rank's sequence portion + # Each CP rank processes a contiguous chunk of the sequence + if teacher_topk_probs is not None: + teacher_topk_probs = self._get_feature_on_this_cp_rank(teacher_topk_probs, "teacher_topk_probs") + if teacher_topk_indices is not None: + teacher_topk_indices = self._get_feature_on_this_cp_rank(teacher_topk_indices, "teacher_topk_indices") + if teacher_topk_log_probs is not None: + teacher_topk_log_probs = self._get_feature_on_this_cp_rank(teacher_topk_log_probs,"teacher_topk_log_probs") + if teacher_topk_inf_mask is not None: + teacher_topk_inf_mask = self._get_feature_on_this_cp_rank(teacher_topk_inf_mask, "teacher_topk_inf_mask") + + # (2) Gather student logits at teacher's top-k indices + # Handles TP-sharded logits with sparse gather operation + # Input: [batch_size, local_seq_len, local_vocab_size] (TP-sharded) + # Output: [batch_size, local_seq_len, topk] (aligned with teacher indices) + full_logits = self.op_compute_gather_by_teacher_indices(logits, teacher_topk_indices) + + # (3) Compute per-token divergence loss + # loss_callable computes divergence (e.g., KL, JSD) between student and teacher distributions + # Returns: [batch_size, local_seq_len] per-token loss + kld_per_token = loss_callable( + logits=full_logits, + teacher_probs=teacher_topk_probs, + teacher_log_probs=teacher_topk_log_probs, + teacher_inf_mask=teacher_topk_inf_mask, + ) + + # (4) Gather per-token loss across CP ranks to restore full sequence + # Input: [batch_size, local_seq_len] (CP-sharded sequence) + # Output: [batch_size, global_seq_len] (full sequence) + cp_size = mpu.get_context_parallel_world_size() + if cp_size > 1: + kld_per_token = context_parallel_gather(kld_per_token, parallel_dim=1) + + # (5) Compute total number of valid (non-padded) tokens + # Uses full labels/attention_mask to count across entire batch + if labels_full is not None: + # Padding positions marked with IGNORE_INDEX in labels + pad_mask = labels_full.eq(IGNORE_INDEX) + else: + # Alternatively use attention_mask where 0 indicates padding + pad_mask = attention_mask_full.eq(0) + token_count = (~pad_mask).sum().float() + + # (6) Early return for 'none' reduction (per-token loss) + if reduction == 'none': + return kld_per_token, token_count + + # (7) Apply padding mask and compute aggregated loss + # Mask out padding positions by setting their loss to 0 + kld_masked = kld_per_token.masked_fill_(pad_mask, 0.0) + loss_sum = kld_masked.sum() + + # (8) Return loss based on reduction method + if reduction == "sum": + # Return sum of loss over all valid tokens + return loss_sum, token_count + elif reduction == "mean": + # Return average loss per valid token + # Clamp token_count to avoid division by zero + return loss_sum / token_count.clamp(min=1.0), token_count + else: + raise ValueError(f"Unsupported reduction: {reduction}. 
Use 'mean', 'sum', or 'none'.") def op_compute_language_loss(self, losses: torch.Tensor, labels: torch.Tensor): - labels = self._get_feature_on_this_cp_rank(labels, "labels") + if not self.use_sequence_packing: + labels = self._get_feature_on_this_cp_rank(labels, "labels") loss_mask = (labels != IGNORE_INDEX).float() loss_mask = loss_mask.view(-1).float() @@ -585,7 +1077,7 @@ def train_step(self, batch: DataProto, loss_func: Callable): is_offload_optimizer_states_in_train_step = batch.meta_info.get("is_offload_optimizer_states_in_train_step", True) if self.worker_config.use_dynamic_batching_in_train: - data_iterator = make_micro_batch_iter_for_dynamic_batching(batch) + micro_batches_list = list(make_micro_batch_iter_for_dynamic_batching(batch)) num_microbatches = batch.meta_info["num_micro_batchs"] mini_batch_size = 1 else: @@ -594,16 +1086,20 @@ def train_step(self, batch: DataProto, loss_func: Callable): assert ( num_microbatches == self.megatron_train_args.gradient_accumulation_steps ), f"num_microbatches={num_microbatches} gradient_accumulation_steps={self.megatron_train_args.gradient_accumulation_steps}" - data_iterator = [ - batch.make_iterator(mini_batch_size=mini_batch_size, epochs=1) for _ in range(len(self.model)) - ] + micro_batches_list = batch.chunk(chunks=num_microbatches) + if self.use_sequence_packing: + mini_batch_size = 1 + self.max_packed_len = self._get_max_packed_len(micro_batches_list) + logger.info(f"max_packed_len: {self.max_packed_len}") + + data_iterator = [iter(micro_batches_list) for _ in range(len(self.model))] metrics_tensors: List[Dict[str, "torch.Tensor"]] = self.forward_backward_func( forward_step_func=partial(self.inner_forward_step, loss_func), data_iterator=data_iterator, model=self.model.get_models(), num_microbatches=num_microbatches, - seq_length=self.seq_length, + seq_length=self.seq_length if not self.use_sequence_packing else self.max_packed_len, micro_batch_size=mini_batch_size, forward_only=False, ) diff --git a/roll/distributed/strategy/strategy.py b/roll/distributed/strategy/strategy.py index 6d52d85e9..d13f675c9 100644 --- a/roll/distributed/strategy/strategy.py +++ b/roll/distributed/strategy/strategy.py @@ -188,80 +188,125 @@ def op_compute_language_loss_from_logits(self, logits: torch.Tensor, targets: to targets.view(-1), ignore_index=IGNORE_INDEX ) - return loss + mask = (targets != IGNORE_INDEX) + valid_tokens = mask.sum() + return loss, valid_tokens - def op_compute_logits(self, logits: torch.Tensor, tp_gather: bool = False, cp_gather: bool = False, topk: int = 0): + def op_compute_topk_logits(self, logits: torch.Tensor, topk: int = 0): + """ + Compute top-k logits from the input logits tensor. + + Args: + logits (torch.Tensor): Input logits tensor of shape [batch_size, ..., vocab_size]. + topk (int): Number of top elements to select. If 0, returns original logits. + + Returns: + tuple: + - If topk == 0: (original logits, empty tensor of shape [batch_size, 1]) + - Otherwise: (top-k logits, top-k indices) from torch.topk """ - Post-process logits. - - If topk == 0 (full-vocab mode), optionally gather across TP/CP ranks - using tp_gather and cp_gather flags. - If topk > 0, return top-K values and indices for each position. - - Args: - logits: [B, local_seq_len, local_vocab_size] tensor. - tp_gather: Gather full vocab across tensor-parallel ranks (only if topk==0). - cp_gather: Gather full sequence across context-parallel ranks (only if topk==0). - topk: 0 for full vocab, >0 for top-K mode. 
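# A minimal sketch of the (loss, token_count) contract introduced above for
# op_compute_language_loss_from_logits: cross-entropy summed over non-ignored tokens,
# returned together with the valid-token count so callers can re-normalize across
# micro-batches or parallel ranks. `ce_with_count` is an illustrative helper.
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100

def ce_with_count(logits, targets, reduction="mean"):
    loss_sum = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
                               ignore_index=IGNORE_INDEX, reduction="sum")
    count = (targets != IGNORE_INDEX).sum()
    if reduction == "sum":
        return loss_sum, count
    return loss_sum / count.clamp(min=1), count

logits = torch.randn(2, 5, 7)
targets = torch.randint(0, 7, (2, 5))
targets[:, 0] = IGNORE_INDEX                   # first position of each row is padding
loss, count = ce_with_count(logits, targets)
print(loss.item(), count.item())               # mean loss over the 8 valid tokens, 8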
- - Returns: - (values, indices): - - full-vocab: (logits, dummy indices) - - top-K: (topk_values, topk_indices)""" if topk == 0: batch_size = logits.shape[0] return logits, torch.empty([batch_size, 1], device=logits.device) else: return torch.topk(logits, k=topk, dim=-1) - def op_compute_prepare_cp_local_iterator(self, tensor: torch.Tensor, feature_name: str, micro_batch_size: int): + def op_compute_topk_probs_and_indices(self, logits: torch.Tensor, topk: int = 0, target_vocab_size: int = None, + kd_temperature: int = 1, teacher_temperature: int = 1): """ - Prepare a microbatch iterator for a tensor that may require Context Parallel (CP) slicing. - - Notes: - - If the input `tensor` is None, this function returns None. - - The input tensor is assumed to have shape [global_batch, global_seq_len, ...], - with batch dimension first and sequence dimension second. - - When CP size > 1, the sequence dimension (dim=1) is sliced to the local CP rank - using `_get_feature_on_this_cp_rank`. - - After CP slicing (if any), the tensor is split along the batch dimension (dim=0) - into microbatches of size `micro_batch_size`. + Compute top-k probabilities, log probabilities, and indices from logits with temperature scaling. Args: - tensor (torch.Tensor or None): Full tensor before CP slicing. Can be None. - feature_name (str): Identifier passed to `_get_feature_on_this_cp_rank` for CP slicing. - e.g., "teacher_logits" or "teacher_topk_indices". - micro_batch_size (int): Number of samples per microbatch; splitting is done along dim=0. + logits (torch.Tensor): Input logits tensor of shape [batch_size, seq_len, vocab_size]. + topk (int): Number of top elements to select. If 0, uses all logits. + target_vocab_size (int, optional): Target vocabulary size to truncate logits. Defaults to None. + kd_temperature (int): Knowledge distillation temperature for scaling. Defaults to 1. + teacher_temperature (int): Teacher model temperature for scaling. Defaults to 1. Returns: - iterator or None: Iterator over microbatches with shape - [micro_batch_size, local_seq_len, ...]. - Returns None if input `tensor` is None. + tuple: (topk_probs, topk_log_probs, topk_indices, topk_inf_mask) + - topk_probs (torch.Tensor): Softmax probabilities of top-k logits. + - topk_log_probs (torch.Tensor): Log softmax probabilities of top-k logits. + - topk_indices (torch.Tensor): Indices of top-k elements. + - topk_inf_mask (torch.Tensor): Boolean mask indicating infinite values in top-k logits. + """ + if target_vocab_size is not None and logits.shape[-1] != target_vocab_size: + logits = logits[:, :, : min(logits.shape[-1], target_vocab_size)] + logits = logits / kd_temperature + logits = logits / teacher_temperature + topk_logits, topk_indices = self.op_compute_topk_logits(logits, topk) + topk_inf_mask = topk_logits.isinf() + topk_probs = F.softmax(topk_logits, dim=-1, dtype=torch.float32) + topk_log_probs = F.log_softmax(topk_logits, dim=-1) + return topk_probs, topk_log_probs, topk_indices, topk_inf_mask + + def op_compute_various_divergence(self, loss_callable, logits, teacher_topk_probs, teacher_topk_log_probs, + teacher_topk_indices, + teacher_topk_inf_mask, labels, attention_mask=None, reduction="mean"): """ - if tensor is None: - return None + Compute divergence loss between student and teacher distributions with support for distributed training. 
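# A standalone sketch of what op_compute_topk_probs_and_indices above ships to the
# student instead of raw teacher logits: temperature-scaled top-k probabilities,
# log-probabilities, global indices and an inf mask. `teacher_topk_package` is an
# illustrative name; the softmax is over the top-k support only, as in the patch.
import torch
import torch.nn.functional as F

def teacher_topk_package(logits, topk, kd_temperature=1.0, teacher_temperature=1.0):
    logits = logits / kd_temperature / teacher_temperature
    topk_logits, topk_indices = torch.topk(logits, k=topk, dim=-1)
    return (F.softmax(topk_logits, dim=-1, dtype=torch.float32),
            F.log_softmax(topk_logits, dim=-1),
            topk_indices,
            topk_logits.isinf())

probs, log_probs, idx, inf_mask = teacher_topk_package(torch.randn(2, 4, 32000), topk=64)
print(probs.shape, idx.shape, inf_mask.any().item())   # [2, 4, 64] [2, 4, 64] False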
- # Microbatch split along batch dimension - return iter(tensor.split(micro_batch_size, dim=0)) + This function handles both Tensor Parallel (TP) and Context Parallel (CP) sharded logits, gathering + full vocabulary logits for the local sequence slice before computing the divergence loss. - def op_compute_various_divergence(self, loss_callable, logits, teacher_logits, teacher_topk_indices, - labels, attention_mask=None): - """ - Note: - `logits` here are both TP (Tensor Parallel) and CP (Context Parallel) sharded. - `logits` here are CP (Context Parallel) sharded. - - We gather across TP to get full-vocab logits for the local CP sequence slice. - `labels`, and `attention_mask` are provided as full tensors - (global sequence length). These are then sliced down to the local CP rank's - sequence shard before loss computation. - """ + Args: + loss_callable (callable): Loss function that computes divergence between student and teacher. + logits (torch.Tensor): Student model logits, potentially TP/CP sharded. Shape: [batch_size, seq_len, vocab_size]. + teacher_topk_probs (torch.Tensor): Teacher's top-k probabilities. + teacher_topk_log_probs (torch.Tensor): Teacher's top-k log probabilities. + teacher_topk_indices (torch.Tensor): Indices of teacher's top-k elements. + teacher_topk_inf_mask (torch.Tensor): Mask for infinite values in teacher's top-k logits. + labels (torch.Tensor, optional): Ground truth labels with padding marked as IGNORE_INDEX. Defaults to None. + attention_mask (torch.Tensor, optional): Attention mask where 0 indicates padding. Used if labels is None. + reduction (str): Reduction method - "mean", "sum", or "none". Defaults to "mean". + Returns: + tuple: (loss, token_count) + - loss (torch.Tensor): Computed loss value based on reduction method. + - "mean": averaged loss over valid tokens + - "sum": summed loss over valid tokens + - "none": per-token loss tensor + - token_count (torch.Tensor): Number of valid (non-padded) tokens. + + Raises: + ValueError: If reduction method is not one of "mean", "sum", or "none". + + Note: + - Input `logits` are both TP (Tensor Parallel) and CP (Context Parallel) sharded. + - The function gathers logits across TP to obtain full-vocab logits for the local CP sequence slice. + - `labels` and `attention_mask` are provided as full tensors with global sequence length, + then sliced to the local CP rank's sequence shard during loss computation. 
+ """ + # Gather full vocabulary logits using teacher's top-k indices full_logits = logits full_logits = self.op_compute_gather_by_teacher_indices(full_logits, teacher_topk_indices) - if teacher_logits.shape[-1] != full_logits.shape[-1]: - teacher_logits = teacher_logits[:, :, : min(full_logits.shape[-1], teacher_logits.shape[-1])] - loss = loss_callable(logits=full_logits, teacher_logits=teacher_logits, labels=labels, attention_mask=attention_mask) - return loss + + # Compute per-token divergence loss + kld_per_token = loss_callable(logits=full_logits, teacher_probs=teacher_topk_probs, + teacher_log_probs=teacher_topk_log_probs, + teacher_inf_mask=teacher_topk_inf_mask) + + # Create padding mask from labels or attention mask + if labels is not None: + pad_mask = labels.eq(IGNORE_INDEX) + else: + pad_mask = attention_mask.eq(0) + token_count = (~pad_mask).sum().float() + + # Early return for 'none' reduction (per-token loss) + if reduction == 'none': + return kld_per_token, token_count + + # Apply mask and compute aggregated loss + kld_masked = kld_per_token.masked_fill_(pad_mask, 0.0) + loss_sum = kld_masked.sum() + + if reduction == "sum": + return loss_sum, token_count + elif reduction == "mean": + return loss_sum / token_count.clamp(min=1.0), token_count + else: + raise ValueError(f"Unsupported reduction: {reduction}. Use 'mean', 'sum', or 'none'.") # Both megatron and deepspeed can output language loss directly. # This op is mainly for computing context-parallel loss. diff --git a/roll/pipeline/agentic/agentic_config.py b/roll/pipeline/agentic/agentic_config.py index affbb81dd..fc2e53cbb 100644 --- a/roll/pipeline/agentic/agentic_config.py +++ b/roll/pipeline/agentic/agentic_config.py @@ -236,6 +236,8 @@ def __post_init__(self): logger.info(f"val_env_manager.max_traj_per_env: {self.val_env_manager.max_traj_per_env}") assert self.val_env_manager.max_traj_per_env >= traj_per_env, f"max_traj_per_env must be >= {traj_per_env}" + self.validate_worker_config() + def make_env_configs(self, env_manager_config: EnvManagerConfig): # construct env configs env_configs = defaultdict(defaultdict) diff --git a/roll/pipeline/distill/distill_config.py b/roll/pipeline/distill/distill_config.py index ade85c563..c9e51a4eb 100644 --- a/roll/pipeline/distill/distill_config.py +++ b/roll/pipeline/distill/distill_config.py @@ -97,11 +97,6 @@ class DistillConfig(BaseConfig): metadata={"help": "Whether to distill on the prompt or not."}, ) - max_length: Optional[int] = field( - default=4096, - metadata={"help": "Max length for DataCollator."} - ) - max_grad_norm: Optional[float] = field( default=0, metadata={"help": "Maximum grad norm"} @@ -135,6 +130,10 @@ def __post_init__(self): self.teacher.name = "teacher" self.student.name = "student" + self.target_vocab_size = None + + self.validate_worker_config() + def to_dict(self): return dataclasses.asdict(self) diff --git a/roll/pipeline/distill/distill_pipeline.py b/roll/pipeline/distill/distill_pipeline.py index 3cd6fa801..c49cd2b7b 100644 --- a/roll/pipeline/distill/distill_pipeline.py +++ b/roll/pipeline/distill/distill_pipeline.py @@ -183,6 +183,7 @@ def __init__(self, pipeline_config: DistillConfig): if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.padding_side = "right" # padding should be on right in distill + pipeline_config.target_vocab_size = self.tokenizer.vocab_size dataset = preprocess_dataset( dataset, diff --git a/roll/pipeline/distill/distill_worker.py b/roll/pipeline/distill/distill_worker.py 
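# A minimal sketch of the masking-and-reduction step shared by the divergence paths
# above: per-token KL is zeroed at padding positions (IGNORE_INDEX in labels) and then
# summed or averaged over the remaining valid tokens. `reduce_divergence` is an
# illustrative helper, assuming labels are available; the attention-mask branch is analogous.
import torch

IGNORE_INDEX = -100

def reduce_divergence(kld_per_token, labels, reduction="mean"):
    pad_mask = labels.eq(IGNORE_INDEX)
    token_count = (~pad_mask).sum().float()
    if reduction == "none":
        return kld_per_token, token_count
    loss_sum = kld_per_token.masked_fill(pad_mask, 0.0).sum()
    if reduction == "sum":
        return loss_sum, token_count
    return loss_sum / token_count.clamp(min=1.0), token_count

kld = torch.rand(2, 6)
labels = torch.randint(0, 100, (2, 6))
labels[:, -2:] = IGNORE_INDEX                  # last two positions are padding
print(reduce_divergence(kld, labels)[0].item())    # mean over the 8 valid tokens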
index 3e98af82a..b86cc7079 100644 --- a/roll/pipeline/distill/distill_worker.py +++ b/roll/pipeline/distill/distill_worker.py @@ -33,13 +33,20 @@ def __init__(self, worker_config: WorkerConfig): self.tokenizer = None self.strategy: Optional[Union[InferenceStrategy, TrainStrategy]] = None self.kl_loss_func = None - self.logits_cache = LogitsCache(self.logger) + self.probs_cache = LogitsCache(self.logger) + self.log_probs_cache = LogitsCache(self.logger) self.topk_indices_cache = LogitsCache(self.logger) - self.tensor_name_to_cache_name = {"logits": "logits_cache", "topk_indices": "topk_indices_cache"} - self.teacher_logits = None + self.inf_mask_cache = LogitsCache(self.logger) + self.tensor_name_to_cache_name = {"topk_probs": "probs_cache", "topk_log_probs": "log_probs_cache", + "topk_indices": "topk_indices_cache", "topk_inf_mask": "inf_mask_cache"} + self.teacher_probs = None + self.teacher_log_probs = None self.teacher_topk_indices = None - self.teacher_logits_iterator = None + self.teacher_inf_mask = None + self.teacher_probs_iterator = None + self.teacher_log_probs_iterator = None self.teacher_topk_indices_iterator = None + self.teacher_inf_mask_iterator = None @register(dispatch_mode=Dispatch.ONE_TO_ALL) def initialize(self, pipeline_config): @@ -70,17 +77,19 @@ def train_step(self, data: DataProto): metrics = {} micro_batch_size = self.worker_config.training_args.per_device_train_batch_size - # Retrieve the teacher logits slice for the current CP rank + # Retrieve the teacher logits if self.rank_info.is_pipeline_last_stage: - self.teacher_logits = self.logits_cache.pop_full_logits() - self.teacher_logits_iterator = self.strategy.op_compute_prepare_cp_local_iterator(self.teacher_logits, - "teacher_logits", micro_batch_size) - # Retrieve the teacher_topk_indices slice for the current CP rank + self.teacher_probs = self.probs_cache.pop_full_logits() + self.teacher_probs_iterator = iter(self.teacher_probs.split(micro_batch_size, dim=0)) + self.teacher_log_probs = self.log_probs_cache.pop_full_logits() + self.teacher_log_probs_iterator = iter(self.teacher_log_probs.split(micro_batch_size, dim=0)) + # Retrieve the teacher_topk_indices if self.rank_info.is_pipeline_last_stage: self.teacher_topk_indices = self.topk_indices_cache.pop_full_logits() if self.pipeline_config.logits_topk != 0: - self.teacher_topk_indices_iterator = self.strategy.op_compute_prepare_cp_local_iterator(self.teacher_topk_indices, - "teacher_topk_indices", micro_batch_size) + self.teacher_topk_indices_iterator = iter(self.teacher_topk_indices.split(micro_batch_size, dim=0)) + self.teacher_inf_mask = self.inf_mask_cache.pop_full_logits() + self.teacher_inf_mask_iterator = iter(self.teacher_inf_mask.split(micro_batch_size, dim=0)) self.logger.info(f"is_offload_states: {is_offload_states}") with state_offload_manger( strategy=self.strategy, @@ -99,7 +108,12 @@ def train_step(self, data: DataProto): backward_batch_size = ( per_device_train_batch_size * self.worker_config.training_args.gradient_accumulation_steps ) - student_metrics = self.strategy.train_step(batch=data, loss_func=self.loss_func) + + loss_func = self.loss_func + if self.worker_config.use_sequence_packing: + from roll.utils.sequence_packing import SequencePackingDistillLossWrapper + loss_func = SequencePackingDistillLossWrapper(self.strategy, loss_func) + student_metrics = self.strategy.train_step(batch=data, loss_func=loss_func) append_to_dict(metrics, student_metrics) data.to("cpu") @@ -116,24 +130,33 @@ def loss_func(self, data: DataProto, output_tensor: 
torch.Tensor): output_tensor: torch.Tensor, the tensor returned by model.forward() """ - student_logits, _ = self.strategy.op_compute_logits(output_tensor, tp_gather=False, cp_gather=False, topk=0) + student_logits = output_tensor labels = data.batch['labels_for_loss'] - attention_mask = data.batch['attention_mask'] # language loss - gpt_loss = self.strategy.op_compute_language_loss_from_logits(student_logits, labels) + gpt_loss, _ = self.strategy.op_compute_language_loss_from_logits(student_logits, labels) # distill loss - if self.teacher_logits_iterator is not None: - teacher_logits = next(self.teacher_logits_iterator) + if self.teacher_probs_iterator is not None: + teacher_probs = next(self.teacher_probs_iterator) + else: + teacher_probs = None + if self.teacher_log_probs_iterator is not None: + teacher_log_probs = next(self.teacher_log_probs_iterator) else: - teacher_logits = None + teacher_log_probs = None if self.teacher_topk_indices_iterator is not None: teacher_topk_indices = next(self.teacher_topk_indices_iterator) else: teacher_topk_indices = None - distill_loss = self.strategy.op_compute_various_divergence(self.kl_loss_func, student_logits, teacher_logits, - teacher_topk_indices, labels, attention_mask) + if self.teacher_inf_mask_iterator is not None: + teacher_inf_mask = next(self.teacher_inf_mask_iterator) + else: + teacher_inf_mask = None + + distill_loss, _ = self.strategy.op_compute_various_divergence(self.kl_loss_func, student_logits, teacher_probs, + teacher_log_probs, teacher_topk_indices, teacher_inf_mask + , labels, attention_mask=None,) loss = ((1 - self.pipeline_config.distill_loss_weight) * gpt_loss + self.pipeline_config.distill_loss_weight * distill_loss) @@ -409,9 +432,11 @@ def __init__(self, worker_config: WorkerConfig): super().__init__(worker_config=worker_config) self.tokenizer = None self.strategy: Optional[Union[InferenceStrategy, TrainStrategy]] = None - # Store the output logits to prevent their GPU memory from being released. - self.logits = None + # Store the output tensors to prevent their GPU memory from being released. 
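# A scalar sketch of how the student objective in loss_func above mixes the language
# modeling loss with the distillation loss; values and the helper name are illustrative.
import torch

def combined_loss(gpt_loss, distill_loss, distill_loss_weight):
    return (1 - distill_loss_weight) * gpt_loss + distill_loss_weight * distill_loss

print(combined_loss(torch.tensor(2.0), torch.tensor(0.5), 0.8).item())   # 0.2*2.0 + 0.8*0.5 = 0.8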
+ self.topk_probs = None + self.topk_log_probs = None self.topk_indices = None + self.topk_inf_mask = None @register(dispatch_mode=Dispatch.ONE_TO_ALL) def initialize(self, pipeline_config): @@ -430,16 +455,22 @@ def initialize(self, pipeline_config): self.strategy.offload_states() + def get_tensor_name_list_for_transfer(self): + return ['topk_probs', 'topk_log_probs', 'topk_indices', 'topk_inf_mask'] + def forward_func(self, data: DataProto, output_tensor: torch.Tensor, non_loss_data: bool = True): - teacher_logits, teacher_indices = self.strategy.op_compute_logits( + topk_probs, topk_log_probs, topk_indices, topk_inf_mask = self.strategy.op_compute_topk_probs_and_indices( output_tensor, - tp_gather=True, - cp_gather=True, - topk=self.pipeline_config.logits_topk + topk=self.pipeline_config.logits_topk, + target_vocab_size=self.pipeline_config.target_vocab_size, + kd_temperature=self.pipeline_config.kd_temperature, + teacher_temperature=self.pipeline_config.teacher_temperature ) return torch.tensor(0., device=output_tensor.device), { - 'logits': teacher_logits.detach(), - 'topk_indices': teacher_indices.detach() + 'topk_probs': topk_probs.detach(), + 'topk_log_probs': topk_log_probs.detach(), + 'topk_indices': topk_indices.detach(), + 'topk_inf_mask': topk_inf_mask.detach() } @register(dispatch_mode=Dispatch.DP_MP_DISPATCH_FIRST_COLLECT_ALL, clear_cache=False) @@ -465,13 +496,23 @@ def forward(self, data: DataProto): data.meta_info["output_on_all_tp_cp_ranks"] = True self.logger.info(f"global_step: {data.meta_info.get('global_step', 0)}") + + forward_func = self.forward_func + if self.worker_config.use_sequence_packing: + from roll.utils.sequence_packing import SequencePackingDistillForwardWrapper + forward_func = SequencePackingDistillForwardWrapper(self.strategy, forward_func) + with torch.no_grad(): - forward_output = self.strategy.forward_step(batch=data, forward_func=self.forward_func) - self.logits = None + forward_output = self.strategy.forward_step(batch=data, forward_func=forward_func) + self.topk_probs = None + self.topk_log_probs = None self.topk_indices = None + self.topk_inf_mask = None if forward_output: - self.logits = forward_output['logits'] + self.topk_probs = forward_output['topk_probs'] + self.topk_log_probs = forward_output['topk_log_probs'] self.topk_indices = forward_output['topk_indices'] + self.topk_inf_mask = forward_output['topk_inf_mask'] output = DataProto(meta_info={"metrics": metrics}).to("cpu") diff --git a/roll/pipeline/distill/logits_transfer_group.py b/roll/pipeline/distill/logits_transfer_group.py index 9f26af080..3b1ab079a 100644 --- a/roll/pipeline/distill/logits_transfer_group.py +++ b/roll/pipeline/distill/logits_transfer_group.py @@ -41,7 +41,8 @@ def __init__(self, src_cluster, tgt_cluster, backend="ipc+nccl"): self.tgt_cluster = tgt_cluster self.backend = backend - self.tensor_name_list_for_transfer = ['logits', 'topk_indices'] + # get tensor list from src cluster + self.tensor_name_list_for_transfer = ray.get(src_cluster.workers[0].get_tensor_name_list_for_transfer.remote()) self.broadcast_comm_pan = defaultdict(lambda: defaultdict(list)) self.p2p_comm_plan = defaultdict(lambda: defaultdict(list)) diff --git a/roll/pipeline/distill/various_divergence.py b/roll/pipeline/distill/various_divergence.py index 0ae76dc7b..4bd06fe53 100644 --- a/roll/pipeline/distill/various_divergence.py +++ b/roll/pipeline/distill/various_divergence.py @@ -8,7 +8,6 @@ class VariousDivergence: def __init__(self, pipeline_config: DistillConfig, padding_id=IGNORE_INDEX) 
-> None: self.kd_temperature = pipeline_config.kd_temperature - self.teacher_temperature = pipeline_config.teacher_temperature self.kd_objective = pipeline_config.kd_objective self.padding_id = padding_id self.args = pipeline_config @@ -28,100 +27,54 @@ def __init__(self, pipeline_config: DistillConfig, padding_id=IGNORE_INDEX) -> N else: raise NameError(f"Unsupported kd_objective for `{self.kd_objective}'") - def __call__(self, logits, teacher_logits, labels, attention_mask=None): - kd_loss = self.dist_func(logits, teacher_logits, labels,attention_mask=attention_mask) - return kd_loss + def __call__(self, logits, teacher_probs, teacher_log_probs, teacher_inf_mask): + kld = self.dist_func(logits, teacher_probs, teacher_log_probs, teacher_inf_mask) + return kld def compute_forward_kl_divergence( self, logits, - teacher_logits, - target, - attention_mask=None, - reduction="mean" + teacher_probs, + teacher_log_probs, + teacher_inf_mask, ): - logits = logits / self.kd_temperature - teacher_logits = teacher_logits / self.kd_temperature - teacher_logits = teacher_logits / self.teacher_temperature - lprobs = torch.log_softmax(logits, -1, dtype=torch.float32) - teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32) - teacher_lprobs = torch.log_softmax(teacher_logits, -1, dtype=torch.float32) - kld = (teacher_probs * (teacher_lprobs - lprobs)) - inf_mask = logits.isinf() - kld = kld.masked_fill_(inf_mask, 0.0).sum(-1) - if reduction == "sum": - if attention_mask is None: - pad_mask = target.eq(self.padding_id) - else: - pad_mask = attention_mask.eq(0) - kld = kld.masked_fill_(pad_mask, 0.0) - kld = kld.sum() - elif reduction == "mean": - if attention_mask is None: - pad_mask = target.eq(self.padding_id) - else: - pad_mask = attention_mask.eq(0) - kld = kld.masked_fill_(pad_mask, 0.0) - num_valid_elements = (~pad_mask).sum().float() - kld = kld.sum() / num_valid_elements + kld = (teacher_probs * (teacher_log_probs - lprobs)) + inf_mask = logits.isinf() | teacher_inf_mask + kld = kld.masked_fill_(inf_mask, 0.0).sum(-1) return kld def compute_reverse_kl_divergence( self, logits, - teacher_logits, - target, - attention_mask=None, - reduction="mean", + teacher_probs, + teacher_log_probs, + teacher_inf_mask, ): logits = logits / self.kd_temperature - teacher_logits = teacher_logits / self.kd_temperature - teacher_logits = teacher_logits / self.teacher_temperature probs = torch.softmax(logits, -1, dtype=torch.float32) lprobs = torch.log_softmax(logits, -1, dtype=torch.float32) - teacher_lprobs = torch.log_softmax(teacher_logits, -1, dtype=torch.float32) - kld = (probs * (lprobs - teacher_lprobs)) - inf_mask = logits.isinf() | teacher_logits.isinf() + kld = (probs * (lprobs - teacher_log_probs)) + inf_mask = logits.isinf() | teacher_inf_mask kld = kld.masked_fill_(inf_mask, 0.0).sum(-1) - if reduction == "sum": - if attention_mask is None: - pad_mask = target.eq(self.padding_id) - else: - pad_mask = attention_mask.eq(0) - kld = kld.masked_fill_(pad_mask, 0.0) - kld = kld.sum() - elif reduction == "mean": - if attention_mask is None: - pad_mask = target.eq(self.padding_id) - else: - pad_mask = attention_mask.eq(0) - kld = kld.masked_fill_(pad_mask, 0.0) - num_valid_elements = (~pad_mask).sum().float() - kld = kld.sum() / num_valid_elements - return kld def compute_adaptive_kl_divergence( self, logits, - teacher_logits, - target, - attention_mask=None, - reduction="mean" + teacher_probs, + teacher_log_probs, + teacher_inf_mask, ): alpha = self.args.adaptive_kl_alpha probs = torch.softmax( 
logits / self.kd_temperature, dim=-1, dtype=torch.float32 ) - teacher_probs = torch.softmax( - teacher_logits / self.teacher_temperature / self.kd_temperature, dim=-1, dtype=torch.float32 - ) sorted_teacher_probs, sorted_idx = teacher_probs.sort(-1) sorted_probs = probs.gather(-1, sorted_idx) gap = (sorted_teacher_probs - sorted_probs).abs() @@ -130,148 +83,71 @@ def compute_adaptive_kl_divergence( g_head = (gap * (1 - tail_mask)).sum(-1).detach() g_tail = (gap * tail_mask).sum(-1).detach() - fkl = self.compute_forward_kl_divergence(logits, teacher_logits, target, attention_mask=attention_mask, reduction="none") - rkl = self.compute_reverse_kl_divergence(logits, teacher_logits, target, attention_mask=attention_mask, reduction="none") + fkl = self.compute_forward_kl_divergence(logits, teacher_probs, teacher_log_probs, teacher_inf_mask) + rkl = self.compute_reverse_kl_divergence(logits, teacher_probs, teacher_log_probs, teacher_inf_mask) akl = (g_head / (g_head + g_tail)) * fkl + (g_tail / (g_head + g_tail)) * rkl - if reduction == "sum": - if attention_mask is None: - pad_mask = target.eq(self.padding_id) - else: - pad_mask = attention_mask.eq(0) - akl = akl.masked_fill_(pad_mask, 0.0) - akl = akl.sum() - elif reduction == "mean": - if attention_mask is None: - pad_mask = target.eq(self.padding_id) - else: - pad_mask = attention_mask.eq(0) - akl = akl.masked_fill_(pad_mask, 0.0) - num_valid_elements = (~pad_mask).sum().float() - akl = akl.sum() / num_valid_elements - return akl def compute_skewed_forward_kl_divergence( self, logits, - teacher_logits, - target, - attention_mask=None, - reduction="mean" + teacher_probs, + teacher_log_probs, + teacher_inf_mask, ): logits = logits / self.kd_temperature - teacher_logits = teacher_logits / self.kd_temperature - teacher_logits = teacher_logits / self.teacher_temperature student_probs = torch.softmax(logits, -1, dtype=torch.float32) - teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32) mixed_probs = self.args.skew_lambda * teacher_probs + (1 - self.args.skew_lambda) * student_probs mixed_lprobs = torch.log(mixed_probs) - teacher_lprobs = torch.log_softmax(teacher_logits, -1, dtype=torch.float32) - kld = (teacher_probs * (teacher_lprobs - mixed_lprobs)) - inf_mask = logits.isinf() | teacher_logits.isinf() + kld = (teacher_probs * (teacher_log_probs - mixed_lprobs)) + inf_mask = logits.isinf() | teacher_inf_mask kld = kld.masked_fill_(inf_mask, 0.0).sum(-1) - if reduction == "sum": - if attention_mask is None: - pad_mask = target.eq(self.padding_id) - else: - pad_mask = attention_mask.eq(0) - kld = kld.masked_fill_(pad_mask, 0.0) - kld = kld.sum() - elif reduction == "mean": - if attention_mask is None: - pad_mask = target.eq(self.padding_id) - else: - pad_mask = attention_mask.eq(0) - kld = kld.masked_fill_(pad_mask, 0.0) - num_valid_elements = (~pad_mask).sum().float() - kld = kld.sum() / num_valid_elements - return kld def compute_skewed_reverse_kl_divergence( self, logits, - teacher_logits, - target, - attention_mask=None, - reduction="mean" + teacher_probs, + teacher_log_probs, + teacher_inf_mask, ): logits = logits / self.kd_temperature - teacher_logits = teacher_logits / self.kd_temperature - teacher_logits = teacher_logits / self.teacher_temperature student_probs = torch.softmax(logits, -1, dtype=torch.float32) - teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32) mixed_probs = (1 - self.args.skew_lambda) * teacher_probs + self.args.skew_lambda * student_probs mixed_lprobs = torch.log(mixed_probs) 
student_lprobs = torch.log_softmax(logits, -1, dtype=torch.float32) kld = (student_probs * (student_lprobs - mixed_lprobs)) - inf_mask = logits.isinf() | teacher_logits.isinf() + inf_mask = logits.isinf() | teacher_inf_mask kld = kld.masked_fill_(inf_mask, 0.0).sum(-1) - if reduction == "sum": - if attention_mask is None: - pad_mask = target.eq(self.padding_id) - else: - pad_mask = attention_mask.eq(0) - kld = kld.masked_fill_(pad_mask, 0.0) - kld = kld.sum() - elif reduction == "mean": - if attention_mask is None: - pad_mask = target.eq(self.padding_id) - else: - pad_mask = attention_mask.eq(0) - kld = kld.masked_fill_(pad_mask, 0.0) - num_valid_elements = (~pad_mask).sum().float() - kld = kld.sum() / num_valid_elements - return kld def compute_js_divergence( self, logits, - teacher_logits, - target, - attention_mask=None, - reduction="mean" + teacher_probs, + teacher_log_probs, + teacher_inf_mask, ): # temperature scaling logits = logits / self.kd_temperature - teacher_logits = teacher_logits / self.kd_temperature - teacher_logits = teacher_logits / self.teacher_temperature probs = torch.softmax(logits, -1, dtype=torch.float32) - teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32) m_probs = (probs + teacher_probs) / 2 lprobs = torch.log(probs + 1e-9) - teacher_lprobs = torch.log(teacher_probs + 1e-9) + teacher_log_probs = torch.log(teacher_probs + 1e-9) m_lprobs = torch.log(m_probs + 1e-9) - kld1 = teacher_probs * (teacher_lprobs - m_lprobs) + kld1 = teacher_probs * (teacher_log_probs - m_lprobs) kld2 = probs * (lprobs - m_lprobs) kld = (kld1 + kld2) / 2 - inf_mask = logits.isinf() | teacher_logits.isinf() + inf_mask = logits.isinf() | teacher_inf_mask kld = kld.masked_fill_(inf_mask, 0.0).sum(-1) - if reduction == "sum": - if attention_mask is None: - pad_mask = target.eq(self.padding_id) - else: - pad_mask = attention_mask.eq(0) - kld = kld.masked_fill_(pad_mask, 0.0) - kld = kld.sum() - elif reduction == "mean": - if attention_mask is None: - pad_mask = target.eq(self.padding_id) - else: - pad_mask = attention_mask.eq(0) - kld = kld.masked_fill_(pad_mask, 0.0) - num_valid_elements = (~pad_mask).sum().float() - kld = kld.sum() / num_valid_elements - - return kld + return kld \ No newline at end of file diff --git a/roll/pipeline/dpo/dpo_config.py b/roll/pipeline/dpo/dpo_config.py index 367e3e5ac..cdc38afba 100644 --- a/roll/pipeline/dpo/dpo_config.py +++ b/roll/pipeline/dpo/dpo_config.py @@ -75,6 +75,8 @@ def __post_init__(self): self.actor_train.name = "actor_train" self.reference.name = "reference" + self.validate_worker_config() + def set_max_steps(self, max_steps: int): self.max_steps = max_steps self.actor_train.training_args.max_steps = max_steps diff --git a/roll/pipeline/rlvr/rlvr_config.py b/roll/pipeline/rlvr/rlvr_config.py index eb0c19fc6..ba51d62ef 100644 --- a/roll/pipeline/rlvr/rlvr_config.py +++ b/roll/pipeline/rlvr/rlvr_config.py @@ -252,5 +252,7 @@ def __post_init__(self): else: self.num_nodes = (max_gpu_num + self.num_gpus_per_node - 1) // self.num_gpus_per_node + self.validate_worker_config() + def to_dict(self): return dataclasses.asdict(self) diff --git a/roll/pipeline/sft/sft_config.py b/roll/pipeline/sft/sft_config.py index d23fab07e..bf60f429e 100644 --- a/roll/pipeline/sft/sft_config.py +++ b/roll/pipeline/sft/sft_config.py @@ -59,5 +59,7 @@ def __post_init__(self): self.sft_train.name = "sft_train" + self.validate_worker_config() + def set_max_steps(self, max_steps: int): self.sft_train.training_args.max_steps = max_steps diff --git 
a/roll/pipeline/sft/sft_pipeline.py b/roll/pipeline/sft/sft_pipeline.py index 176aba63b..163ea2751 100644 --- a/roll/pipeline/sft/sft_pipeline.py +++ b/roll/pipeline/sft/sft_pipeline.py @@ -22,7 +22,6 @@ logger = get_logger() -# TODO: support packing def preprocess_dataset(dataset, prompt_len, encode_func, num_proc): logger.info(f"Begin process dataset: {dataset}") dataset = dataset.map( diff --git a/roll/pipeline/sft/sft_worker.py b/roll/pipeline/sft/sft_worker.py index ceca505b1..8d63bd51e 100644 --- a/roll/pipeline/sft/sft_worker.py +++ b/roll/pipeline/sft/sft_worker.py @@ -31,7 +31,13 @@ def initialize(self, pipeline_config): def train_step(self, data: DataProto): data = data.to(current_platform.device_type) data = self.strategy.get_data_input(data) - metrics = self.strategy.train_step(batch=data, loss_func=self.loss_func) + + loss_func = self.loss_func + if self.worker_config.use_sequence_packing: + from roll.utils.sequence_packing import SequencePackingSFTLossWrapper + loss_func = SequencePackingSFTLossWrapper(self.strategy, loss_func) + + metrics = self.strategy.train_step(batch=data, loss_func=loss_func) output = DataProto(meta_info={"metrics": metrics}).to("cpu") return output @@ -62,4 +68,4 @@ def do_checkpoint(self, global_step): def loss_func(self, data: DataProto, output_tensor: torch.Tensor): labels = data.batch["labels"] - return self.strategy.op_compute_language_loss(output_tensor, labels) + return self.strategy.op_compute_language_loss(output_tensor, labels) \ No newline at end of file diff --git a/roll/utils/sequence_packing.py b/roll/utils/sequence_packing.py new file mode 100644 index 000000000..ec9c6f3a0 --- /dev/null +++ b/roll/utils/sequence_packing.py @@ -0,0 +1,356 @@ +import torch + +from roll.distributed.scheduler.protocol import DataProto +from roll.platforms import current_platform +from roll.utils.constants import IGNORE_INDEX + +""" +Loss computation wrappers for sequence packing training. +Handles unpacking model outputs and aligning with original sequence boundaries for loss calculation. +""" + + +# TODO: use view of tensor in loss caculating instead of copy +class SequencePackingLossWrapper: + """ + Base wrapper for computing loss on packed sequences. + + In sequence packing, multiple sequences are concatenated and padded to form a single packed sequence. + This wrapper handles: + 1. Unpacking model outputs back to individual sequences + 2. Aligning original data (labels, masks) with unpacked outputs + 3. Computing loss on properly aligned data + """ + + def __init__( + self, + strategy, + loss_func, + ): + """ + Args: + strategy: Training strategy containing model and distributed config + loss_func: Loss function to apply + cu_seqlens_q: Cumulative sequence lengths of original (unpadded) sequences + cu_seqlens_q_padded: Cumulative sequence lengths after padding for packing + logger: Optional logger + """ + self.strategy = strategy + self.loss_func = loss_func + self.cu_seqlens = None + self.cu_seqlens_padded = None + self.logger = None + + def set_packing_params(self, cu_seqlens, cu_seqlens_padded, logger): + self.cu_seqlens = cu_seqlens + self.cu_seqlens_padded = cu_seqlens_padded + self.logger = logger + + def _unpack_output_tensor(self, output_tensor): + """ + Unpack model output tensor from packed format back to individual sequences. + + The packed output contains multiple sequences concatenated together. This method + splits them back using padded cumulative sequence lengths, accounting for context + parallelism partitioning. 
+ + Args: + output_tensor: Packed model output with shape (batch=1, packed_seq_len, hidden_dim) + + Returns: + List of unpacked tensors, one per original sequence, each with shape + (batch=1, padded_seq_len, hidden_dim) + """ + cp_size = self.strategy.worker.rank_info.cp_size + + # Calculate sequence boundaries in the packed tensor + # Padded cumulative lengths mark where each sequence starts/ends after packing + padded_cu_seqlens = self.cu_seqlens_padded + + # Adjust for context parallelism: each rank only holds a portion of the sequence + seq_starts = padded_cu_seqlens[:-1] // cp_size + seq_ends = padded_cu_seqlens[1:] // cp_size + + # Extract each sequence from the packed tensor + unpacked_output_tensor_list = [] + for seq_idx, (seq_start, seq_end) in enumerate(zip(seq_starts, seq_ends)): + unpacked_output_tensor_list.append(output_tensor[:, seq_start:seq_end, :]) + return unpacked_output_tensor_list + + def _pad_tensor_to_target_length(self, tensor, target_length, pad_val=0, pad_dim=0): + """ + Pad tensor along the specified dimension to reach the target length by padding on the right. + + Args: + tensor: Input tensor to pad + target_length: Desired length along pad_dim + pad_val: Value to use for padding + pad_dim: Dimension to pad along + + Returns: + Padded tensor with length target_length along pad_dim + """ + seq_len = tensor.shape[pad_dim] + + if target_length > seq_len: + pad_size = target_length - seq_len + + # Construct padding specification for torch.nn.functional.pad + # Format: [pad_left, pad_right] for each dim from last to first + pad_list = [0, 0] * tensor.ndim + pad_list[2 * (tensor.ndim - 1 - pad_dim) + 1] = pad_size + + tensor = torch.nn.functional.pad(tensor, pad_list, value=pad_val) + + return tensor + + def _align_to_unpacked_output_tensor_shape(self, tensor, pad_val=0): + """ + Align original data tensors (labels, masks) to match unpacked output shape. + + Original data comes in shape (batch, max_seq_len, ...) where batch contains multiple + sequences with varying actual lengths. This method: + 1. Extracts each sequence's valid portion (up to its original unpadded length) + 2. Pads it to match the padded length used during packing + + This ensures original data aligns with unpacked model outputs for loss computation. + + Args: + tensor: Original data tensor with shape (batch, seq_len, ...) + pad_val: Value used for padding (e.g., IGNORE_INDEX for labels, 0 for masks) + + Returns: + List of aligned tensors, each with shape (1, padded_seq_len, ...) 
matching + the corresponding unpacked output tensor + """ + # Get original unpadded sequence lengths (actual data before packing) + unpadded_seq_lengths = self.cu_seqlens[1:] - self.cu_seqlens[:-1] + + # Get padded sequence lengths (after padding during packing) + padded_seq_lengths = self.cu_seqlens_padded[1:] - self.cu_seqlens_padded[:-1] + + source_seq_lengths = unpadded_seq_lengths # Valid data length + target_seq_lengths = padded_seq_lengths # Target length after packing + + aligned_tensor_list = [] + for seq_idx, (source_len, target_len) in enumerate( + zip(source_seq_lengths, target_seq_lengths) + ): + # Extract valid portion: truncate to original unpadded length + seq_tensor = tensor[seq_idx:seq_idx + 1, :source_len] + + # Pad to match the padded length used in packing + seq_tensor = self._pad_tensor_to_target_length(seq_tensor, target_len, pad_val=pad_val, pad_dim=1) + + # Keep batch dimension (1) to match unpacked output format + aligned_tensor_list.append(seq_tensor) + + return aligned_tensor_list + + def __call__(self, data: DataProto, output_tensor: torch.Tensor): + return self.loss_func(data, output_tensor) + + +# SFT +class SequencePackingSFTLossWrapper(SequencePackingLossWrapper): + """ + Wrapper for SFT loss computation with packed sequences. + + For SFT, labels are already packed in the same format as model outputs, + so we can directly compute loss without unpacking. + """ + + def __call__(self, data: DataProto, output_tensor: torch.Tensor): + # Use pre-packed labels that match the packed output format + labels = data.meta_info['labels_packed'] + return self.loss_func(DataProto.from_dict(tensors={'labels': labels}), output_tensor) + + +# Distillation +class SequencePackingDistillForwardWrapper(SequencePackingLossWrapper): + """ + Wrapper for teacher model forward pass in distillation with packed sequences. + + Computes teacher logits from packed outputs and prepares them for student training: + 1. Unpacks teacher outputs to individual sequences + 2. Computes full vocabulary logits or topk logits for each sequence + 3. Pads logits back to original max sequence length for easy alignment with student + """ + + def __init__(self, strategy, loss_func): + super().__init__(strategy, loss_func) + self.forward_func = loss_func + + def __call__(self, data: DataProto, output_tensor: torch.Tensor, non_loss_data: bool = True): + """ + Compute teacher logits from packed outputs. 
+ + Args: + data: Input data protocol + output_tensor: Packed teacher model outputs + non_loss_data: Flag indicating this is for data generation, not loss computation + + Returns: + Tuple of (dummy_loss, dict with teacher logits and topk indices) + """ + # Step 1: Unpack teacher outputs to individual sequences + unpacked_output_tensor_list = self._unpack_output_tensor(output_tensor) + + # Step 2: Compute logits for each sequence + # Gather across tensor/context parallel ranks to get full logits + teacher_topk_probs_list = [] + teacher_topk_log_probs_list = [] + teacher_topk_indices_list = [] + teacher_topk_inf_mask_list = [] + for idx, unpacked_output_tensor in enumerate(unpacked_output_tensor_list): + # Compute logits with full vocabulary (or topk for efficiency) + teacher_topk_probs, teacher_topk_log_probs, teacher_topk_indices, teacher_topk_inf_mask = self.strategy.op_compute_topk_probs_and_indices( + unpacked_output_tensor, + topk=self.strategy.worker.pipeline_config.logits_topk, + target_vocab_size=self.strategy.worker.pipeline_config.target_vocab_size, + kd_temperature=self.strategy.worker.pipeline_config.kd_temperature, + teacher_temperature=self.strategy.worker.pipeline_config.teacher_temperature + ) + + # Step 3: Pad each sequence's logits to max sequence length + # This makes them easy to align with original student data later + max_length = self.strategy.worker.pipeline_config.sequence_length + teacher_topk_probs = self._pad_tensor_to_target_length(teacher_topk_probs, max_length, pad_val=0, pad_dim=1) + teacher_topk_log_probs = self._pad_tensor_to_target_length(teacher_topk_log_probs, max_length, pad_val=0, pad_dim=1) + teacher_topk_indices = self._pad_tensor_to_target_length(teacher_topk_indices, max_length, pad_val=0, pad_dim=1) + teacher_topk_inf_mask = self._pad_tensor_to_target_length(teacher_topk_inf_mask, max_length, pad_val=1, pad_dim=1) + + teacher_topk_probs_list.append(teacher_topk_probs) + teacher_topk_log_probs_list.append(teacher_topk_log_probs) + teacher_topk_indices_list.append(teacher_topk_indices) + teacher_topk_inf_mask_list.append(teacher_topk_inf_mask) + + # Concatenate all sequences back into batch format + teacher_topk_probs = torch.cat(teacher_topk_probs_list, dim=0) + teacher_topk_log_probs = torch.cat(teacher_topk_log_probs_list, dim=0) + teacher_topk_indices = torch.cat(teacher_topk_indices_list, dim=0) + teacher_topk_inf_mask = torch.cat(teacher_topk_inf_mask_list, dim=0) + + # Return dummy loss (teacher forward doesn't compute loss) and teacher outputs + return torch.tensor(0., device=output_tensor.device), { + 'topk_probs': teacher_topk_probs.detach(), + 'topk_log_probs': teacher_topk_log_probs.detach(), + 'topk_indices': teacher_topk_indices.detach(), + 'topk_inf_mask': teacher_topk_inf_mask.detach() + } + + +class SequencePackingDistillLossWrapper(SequencePackingLossWrapper): + """ + Wrapper for computing distillation loss with packed sequences. + + Combines language modeling loss and distillation loss: + 1. Unpacks student model outputs to individual sequences + 2. Aligns original labels and teacher outputs with unpacked student outputs + 3. Computes both standard LM loss and KL divergence with teacher for each sequence + 4. Combines losses with configurable weighting + """ + + def __call__(self, data: DataProto, output_tensor: torch.Tensor): + """ + Compute combined distillation and language modeling loss. 
+ + Args: + data: Input data containing original labels and masks + output_tensor: Packed student model outputs + + Returns: + Tuple of (total_loss, metrics_dict) + """ + # Step 1: Compute student logits from packed outputs + # Keep them partitioned across tensor/context parallel for memory efficiency + student_logits = output_tensor + + # Step 2: Unpack student logits to individual sequences (still cp-partitioned) + student_logits_list = self._unpack_output_tensor(student_logits) + + # Step 3: Get original data from dataloader (not packed) + labels = data.batch['labels_for_loss'] + attention_mask = data.batch['attention_mask'] + + # Step 4: Align original data with unpacked outputs + # Truncate to original length and pad to match packing padding + aligned_labels_list = self._align_to_unpacked_output_tensor_shape(labels, pad_val=IGNORE_INDEX) + aligned_attention_mask_list = self._align_to_unpacked_output_tensor_shape(attention_mask, pad_val=0) + + # Step 5: Get and align teacher outputs (pre-computed in teacher forward pass) + if self.strategy.worker.teacher_probs_iterator is not None: + teacher_probs = next(self.strategy.worker.teacher_probs_iterator) + aligned_teacher_probs_list = self._align_to_unpacked_output_tensor_shape(teacher_probs) + else: + teacher_probs = None + if self.strategy.worker.teacher_log_probs_iterator is not None: + teacher_log_probs = next(self.strategy.worker.teacher_log_probs_iterator) + aligned_teacher_log_probs_list = self._align_to_unpacked_output_tensor_shape(teacher_log_probs) + else: + teacher_log_probs = None + if self.strategy.worker.teacher_topk_indices_iterator is not None: + teacher_topk_indices = next(self.strategy.worker.teacher_topk_indices_iterator) + aligned_teacher_topk_indices_list = self._align_to_unpacked_output_tensor_shape(teacher_topk_indices) + else: + teacher_topk_indices = None + if self.strategy.worker.teacher_inf_mask_iterator is not None: + teacher_inf_mask = next(self.strategy.worker.teacher_inf_mask_iterator) + aligned_teacher_inf_mask_list = self._align_to_unpacked_output_tensor_shape(teacher_inf_mask) + else: + teacher_inf_mask = None + + + # Step 6: Accumulate losses across all sequences in the batch + total_gpt_loss = torch.tensor(0, device=current_platform.device_type, dtype=torch.float32) + total_distill_loss = torch.tensor(0, device=current_platform.device_type, dtype=torch.float32) + total_valid_tokens = 0 + total_valid_tokens_distill = 0 + + batch_size = len(student_logits_list) + for idx in range(batch_size): + # Get aligned data for this sequence + single_student_logits = student_logits_list[idx] + single_label = aligned_labels_list[idx] + single_teacher_probs = aligned_teacher_probs_list[idx] if teacher_probs is not None else None + single_teacher_log_probs = aligned_teacher_log_probs_list[idx] if teacher_log_probs is not None else None + single_teacher_topk_indices = aligned_teacher_topk_indices_list[idx] if teacher_topk_indices is not None else None + single_teacher_inf_mask = aligned_teacher_inf_mask_list[idx] if teacher_inf_mask is not None else None + + # Compute standard language modeling loss (cross-entropy with labels) + local_gpt_loss, local_valid_tokens = self.strategy.op_compute_language_loss_from_logits( + single_student_logits, single_label, + reduction="sum") + total_gpt_loss += local_gpt_loss + total_valid_tokens += local_valid_tokens + + # Compute distillation loss (KL divergence between student and teacher) + local_distill_loss, local_valid_tokens_distill = self.strategy.op_compute_various_divergence( + 
self.strategy.worker.kl_loss_func, + single_student_logits, single_teacher_probs, + single_teacher_log_probs, single_teacher_topk_indices, + single_teacher_inf_mask, single_label, + attention_mask=None, reduction="sum") + + total_distill_loss += local_distill_loss + total_valid_tokens_distill += local_valid_tokens_distill + + # Step 7: Normalize losses by number of valid tokens + if total_valid_tokens == 0: + total_valid_tokens = 1 + if total_valid_tokens_distill == 0: + total_valid_tokens_distill = 1 + gpt_loss = total_gpt_loss / total_valid_tokens + distill_loss = total_distill_loss / total_valid_tokens_distill + + # Step 8: Combine losses with configured weighting + # loss = (1 - α) * LM_loss + α * distill_loss + loss = ((1 - self.strategy.worker.pipeline_config.distill_loss_weight) * gpt_loss + + self.strategy.worker.pipeline_config.distill_loss_weight * distill_loss) + + student_metrics = { + "train/loss": loss.detach().item(), + "train/train_distill_loss": distill_loss.detach().item(), + "train/train_student_loss": gpt_loss.detach().item(), + } + return loss, student_metrics From 4a68470e3ff4b12f64cbb4d7995b4da246d99015 Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Mon, 24 Nov 2025 15:30:13 +0800 Subject: [PATCH 49/58] (feat): add alive check. --- roll/pipeline/base_worker.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py index efbfad175..10744a66e 100644 --- a/roll/pipeline/base_worker.py +++ b/roll/pipeline/base_worker.py @@ -404,13 +404,16 @@ def add_request(self, command, data: DataProto): response_callback_fn: callable generation_config, 按request设置 """ - if command == GenerateRequestType.ALIVE_CHECK: + def alive_check(): if self.thread_server is not None: if not self.thread_server.is_alive(): raise Exception("thread server has stopped unexpectedly. check stderr for more info.") + if command == GenerateRequestType.ALIVE_CHECK: + alive_check() output = DataProto(meta_info={"request_counts": len(self.response_call_back_fns)}) return output elif command == GenerateRequestType.ADD: + alive_check() assert "response_callback_fn" in data.meta_info, "response_callback_fn is not in data.meta_info" is_num_return_sequences_expand = data.meta_info.get("is_num_return_sequences_expand", False) if "generation_config" not in data.meta_info: From 21460df7039b0716855b281d5f56d11944bb43e2 Mon Sep 17 00:00:00 2001 From: "tianhe.lzd" Date: Tue, 25 Nov 2025 10:37:48 +0800 Subject: [PATCH 50/58] (feat): sglang support dp-attention. 
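A note on the sequence-packing wrappers introduced in the previous patch: they recover per-sequence views from a packed batch using two sets of cumulative lengths, `cu_seqlens` for the real tokens and `cu_seqlens_padded` for the per-sequence padding applied during packing. The snippet below is a minimal sketch of that bookkeeping with toy shapes, no context parallelism (`cp_size == 1`), and plain tensors instead of the repository's `DataProto`; it illustrates the slicing, not the project's actual API.

```python
import torch

IGNORE_INDEX = -100  # placeholder for the label padding value

# Toy packed batch: three sequences with real lengths 3, 5, 2,
# each right-padded to a multiple of 4 before concatenation.
cu_seqlens = torch.tensor([0, 3, 8, 10])          # boundaries of real tokens
cu_seqlens_padded = torch.tensor([0, 4, 12, 16])  # boundaries after padding
hidden = 8
packed = torch.randn(1, int(cu_seqlens_padded[-1]), hidden)  # (1, packed_len, hidden)

# Unpack: one slice per sequence, still carrying its right padding.
bounds = cu_seqlens_padded.tolist()
unpacked = [packed[:, start:end, :] for start, end in zip(bounds[:-1], bounds[1:])]

# Align original labels (batch, max_seq_len) with the unpacked slices:
# truncate to the real length, then right-pad to the packed length.
labels = torch.randint(0, 50, (3, 6))
real_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
padded_lens = (cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]).tolist()
aligned = []
for i, (real_len, padded_len) in enumerate(zip(real_lens, padded_lens)):
    seq = labels[i : i + 1, :real_len]
    seq = torch.nn.functional.pad(seq, (0, padded_len - real_len), value=IGNORE_INDEX)
    aligned.append(seq)

# Every unpacked slice now lines up position-by-position with its labels.
assert all(u.shape[1] == a.shape[1] for u, a in zip(unpacked, aligned))
```

The real wrappers additionally divide the padded boundaries by `cp_size` and read the cumulative lengths from the packing metadata; the slice-then-repad step is the part this sketch isolates.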
--- roll/third_party/sglang/v052_patch/engine.py | 3 ++- roll/third_party/sglang/v052_patch/scheduler.py | 9 +++++++++ .../sglang/v052_patch/tokenizer_manager.py | 15 --------------- roll/third_party/sglang/v054_patch/engine.py | 4 +++- roll/third_party/sglang/v054_patch/scheduler.py | 7 +++++++ .../sglang/v054_patch/tokenizer_manager.py | 15 --------------- 6 files changed, 21 insertions(+), 32 deletions(-) diff --git a/roll/third_party/sglang/v052_patch/engine.py b/roll/third_party/sglang/v052_patch/engine.py index 055766968..12fe03aa1 100644 --- a/roll/third_party/sglang/v052_patch/engine.py +++ b/roll/third_party/sglang/v052_patch/engine.py @@ -103,10 +103,11 @@ def __init__(self, _launch_subprocesses): def __call__(self, *args, **kwargs): import sys from roll.third_party.sglang.v052_patch.tokenizer_manager import TokenizerManagerSA - from roll.third_party.sglang.v052_patch.scheduler import run_scheduler_process + from roll.third_party.sglang.v052_patch.scheduler import run_scheduler_process, run_data_parallel_controller_process sys.modules['sglang.srt.entrypoints.engine'].__dict__['TokenizerManager'] = TokenizerManagerSA sys.modules['sglang.srt.entrypoints.engine'].__dict__['run_scheduler_process'] = run_scheduler_process + sys.modules['sglang.srt.entrypoints.engine'].__dict__['run_data_parallel_controller_process'] = run_data_parallel_controller_process return self._launch_subprocesses(*args, **kwargs) diff --git a/roll/third_party/sglang/v052_patch/scheduler.py b/roll/third_party/sglang/v052_patch/scheduler.py index 4f74e28e0..48405d4df 100644 --- a/roll/third_party/sglang/v052_patch/scheduler.py +++ b/roll/third_party/sglang/v052_patch/scheduler.py @@ -97,3 +97,12 @@ def run_scheduler_process(*args, **kwargs): from sglang.srt.managers.scheduler import run_scheduler_process return run_scheduler_process(*args, **kwargs) + +def run_data_parallel_controller_process(*args, **kwargs): + import sys + sys.modules['sglang.srt.managers.data_parallel_controller'].__dict__['run_scheduler_process'] = run_scheduler_process + from sglang.srt.managers.data_parallel_controller import ( + run_data_parallel_controller_process, + ) + return run_data_parallel_controller_process(*args, **kwargs) + diff --git a/roll/third_party/sglang/v052_patch/tokenizer_manager.py b/roll/third_party/sglang/v052_patch/tokenizer_manager.py index c3708bbea..fd84c0f3c 100644 --- a/roll/third_party/sglang/v052_patch/tokenizer_manager.py +++ b/roll/third_party/sglang/v052_patch/tokenizer_manager.py @@ -72,9 +72,6 @@ async def setup_collective_group( request: Optional[fastapi.Request] = None, ) -> Tuple[bool, str]: self.auto_create_handle_loop() - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for init parameter update group" result = (await self.setup_collective_group_communicator(obj))[0] return result.success, result.message @@ -84,9 +81,6 @@ async def broadcast_bucket( request: Optional[fastapi.Request] = None, ) -> Tuple[bool, str]: self.auto_create_handle_loop() - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for init parameter update group" result = (await self.broadcast_bucket_communicator(obj))[0] return result.success, result.message @@ -96,9 +90,6 @@ async def broadcast_parameter( request: Optional[fastapi.Request] = None, ) -> Tuple[bool, str]: self.auto_create_handle_loop() - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for init parameter update group" result = (await self.broadcast_parameter_communicator(obj))[0] return result.success, result.message @@ 
-108,9 +99,6 @@ async def update_parameter( request: Optional[fastapi.Request] = None, ) -> Tuple[bool, str]: self.auto_create_handle_loop() - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for init parameter update group" result = (await self.update_parameter_communicator(obj))[0] return result.success, result.message @@ -120,8 +108,5 @@ async def update_parameter_in_bucket( request: Optional[fastapi.Request] = None, ) -> Tuple[bool, str]: self.auto_create_handle_loop() - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for init parameter update group" result = (await self.update_parameter_in_bucket_communicator(obj))[0] return result.success, result.message \ No newline at end of file diff --git a/roll/third_party/sglang/v054_patch/engine.py b/roll/third_party/sglang/v054_patch/engine.py index 99489347c..df7f7ba56 100644 --- a/roll/third_party/sglang/v054_patch/engine.py +++ b/roll/third_party/sglang/v054_patch/engine.py @@ -103,11 +103,13 @@ def __init__(self, _launch_subprocesses): def __call__(self, *args, **kwargs): import sys from roll.third_party.sglang.v054_patch.tokenizer_manager import TokenizerManagerSA - from roll.third_party.sglang.v054_patch.scheduler import run_scheduler_process + from roll.third_party.sglang.v054_patch.scheduler import run_scheduler_process, run_data_parallel_controller_process sys.modules['sglang.srt.entrypoints.engine'].__dict__['TokenizerManager'] = TokenizerManagerSA sys.modules['sglang.srt.entrypoints.engine'].__dict__['run_scheduler_process'] = run_scheduler_process + sys.modules['sglang.srt.entrypoints.engine'].__dict__['run_data_parallel_controller_process'] = run_data_parallel_controller_process return self._launch_subprocesses(*args, **kwargs) + engine_module._launch_subprocesses = _roll_launch_subprocesses(engine_module._launch_subprocesses) \ No newline at end of file diff --git a/roll/third_party/sglang/v054_patch/scheduler.py b/roll/third_party/sglang/v054_patch/scheduler.py index a5cb49b1d..ed87999ad 100644 --- a/roll/third_party/sglang/v054_patch/scheduler.py +++ b/roll/third_party/sglang/v054_patch/scheduler.py @@ -96,3 +96,10 @@ def run_scheduler_process(*args, **kwargs): from sglang.srt.managers.scheduler import run_scheduler_process return run_scheduler_process(*args, **kwargs) +def run_data_parallel_controller_process(*args, **kwargs): + import sys + sys.modules['sglang.srt.managers.data_parallel_controller'].__dict__['run_scheduler_process'] = run_scheduler_process + from sglang.srt.managers.data_parallel_controller import ( + run_data_parallel_controller_process, + ) + return run_data_parallel_controller_process(*args, **kwargs) diff --git a/roll/third_party/sglang/v054_patch/tokenizer_manager.py b/roll/third_party/sglang/v054_patch/tokenizer_manager.py index c3708bbea..fd84c0f3c 100644 --- a/roll/third_party/sglang/v054_patch/tokenizer_manager.py +++ b/roll/third_party/sglang/v054_patch/tokenizer_manager.py @@ -72,9 +72,6 @@ async def setup_collective_group( request: Optional[fastapi.Request] = None, ) -> Tuple[bool, str]: self.auto_create_handle_loop() - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for init parameter update group" result = (await self.setup_collective_group_communicator(obj))[0] return result.success, result.message @@ -84,9 +81,6 @@ async def broadcast_bucket( request: Optional[fastapi.Request] = None, ) -> Tuple[bool, str]: self.auto_create_handle_loop() - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for init parameter update group" result 
= (await self.broadcast_bucket_communicator(obj))[0] return result.success, result.message @@ -96,9 +90,6 @@ async def broadcast_parameter( request: Optional[fastapi.Request] = None, ) -> Tuple[bool, str]: self.auto_create_handle_loop() - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for init parameter update group" result = (await self.broadcast_parameter_communicator(obj))[0] return result.success, result.message @@ -108,9 +99,6 @@ async def update_parameter( request: Optional[fastapi.Request] = None, ) -> Tuple[bool, str]: self.auto_create_handle_loop() - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for init parameter update group" result = (await self.update_parameter_communicator(obj))[0] return result.success, result.message @@ -120,8 +108,5 @@ async def update_parameter_in_bucket( request: Optional[fastapi.Request] = None, ) -> Tuple[bool, str]: self.auto_create_handle_loop() - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for init parameter update group" result = (await self.update_parameter_in_bucket_communicator(obj))[0] return result.success, result.message \ No newline at end of file From 1c45b7a63750131b990d2fab466f049b6de1547b Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Wed, 3 Dec 2025 21:07:49 +0800 Subject: [PATCH 51/58] (fix): set broadcast_non_tensor_batch for old_logprobs. --- roll/configs/base_config.py | 16 +++++++++------- roll/pipeline/agentic/agentic_pipeline.py | 2 +- roll/pipeline/base_worker.py | 16 ++++++++-------- roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py | 2 +- roll/pipeline/rlvr/rlvr_pipeline.py | 3 ++- roll/pipeline/rlvr/rlvr_vlm_pipeline.py | 2 +- 6 files changed, 22 insertions(+), 19 deletions(-) diff --git a/roll/configs/base_config.py b/roll/configs/base_config.py index ca7100a2c..aa1da4c27 100644 --- a/roll/configs/base_config.py +++ b/roll/configs/base_config.py @@ -398,8 +398,8 @@ class PPOConfig(BaseConfig): enable_reference: bool = field( default=False, metadata={"help": "Whether to enable reference cluster for computing ref_log_probs."} ) - enable_old_logprobs: bool = field(default=False, metadata={"help": "Enable old_logprobs computation optimization for disable caching"}) - force_disable_old_logprobs: bool = field(default=False, metadata={"help": "Force disable old_logprobs computation optimization for disable caching, priority is higher than enable_old_logprobs"}) + enable_old_logprobs_recompute: bool = field(default=False, metadata={"help": "Enable old_logprobs computation optimization for disable caching"}) + force_disable_old_logprobs_recompute: bool = field(default=False, metadata={"help": "Force disable old_logprobs computation optimization for disable caching, priority is higher than enable_old_logprobs_recompute"}) def __post_init__(self): super().__post_init__() @@ -427,11 +427,13 @@ def __post_init__(self): if self.use_kl_loss or self.init_kl_coef > 0: logger.warning(f"use_kl_loss or init_kl_coef > 0, enable_reference = True") self.enable_reference = True - if self.force_disable_old_logprobs: - self.enable_old_logprobs = False + if self.force_disable_old_logprobs_recompute: + self.enable_old_logprobs_recompute = False else: self.set_old_logprobs_status() + logger.info(f"enable_old_logprobs_recompute: {self.enable_old_logprobs_recompute}\tenable_reference: {self.enable_reference}") + def set_max_steps(self, max_steps: int): actor_backward_batch_size = ( self.actor_train.training_args.per_device_train_batch_size @@ -487,11 +489,11 @@ def 
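This change renames the recompute switch to `enable_old_logprobs_recompute` and keeps the `sample_uuid`-keyed log-prob cache on CPU, moving cached entries back to the compute device on a hit (see `get_old_log_probs_with_cache` in the hunks below). A simplified sketch of that caching idea, with a plain dict and toy tensors standing in for the worker state; the function name here is made up for illustration:

```python
import torch

logprobs_cache = {}  # sample_uuid -> cached old_log_probs, kept on CPU between passes

def old_log_probs_with_cache(sample_uuids, log_probs, ppo_epochs=2, device="cpu"):
    """On the first pass over a batch, the freshly computed log_probs double as
    old_log_probs (on-policy) and are cached; later passes reuse the cached copy."""
    if all(uid in logprobs_cache for uid in sample_uuids):
        cached = [logprobs_cache[uid] for uid in sample_uuids]
        return torch.cat(cached, dim=0).to(device)  # cache hit: move back to the compute device
    old_log_probs = log_probs.detach()
    if ppo_epochs > 1:  # caching only pays off when the same samples are revisited
        for i, uid in enumerate(sample_uuids):
            logprobs_cache[uid] = old_log_probs[i : i + 1].cpu()
    return old_log_probs

# Same samples seen twice (e.g. ppo_epochs = 2): the second pass reuses the cache.
uuids = ["sample-a", "sample-b"]
first = old_log_probs_with_cache(uuids, torch.randn(2, 5))
second = old_log_probs_with_cache(uuids, torch.randn(2, 5))
assert torch.equal(first, second)
```

The production method additionally falls back to the precomputed `old_log_probs` when `enable_old_logprobs_recompute` is set or no `sample_uuid` is present; those branches are omitted here.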
set_old_logprobs_status(self): if backward_steps_per_rank > 1: # Multiple backward steps means model parameters change during training # Cannot reuse cached logprobs across backward passes - self.enable_old_logprobs = True + self.enable_old_logprobs_recompute = True if self.init_kl_coef > 0: - logger.warning(f"init_kl_coef > 0, enable_old_logprobs = True") - self.enable_old_logprobs = True + logger.warning(f"init_kl_coef > 0, enable_old_logprobs_recompute = True") + self.enable_old_logprobs_recompute = True @property def async_pipeline(self) -> bool: diff --git a/roll/pipeline/agentic/agentic_pipeline.py b/roll/pipeline/agentic/agentic_pipeline.py index 925d56fcf..4e666e6fd 100644 --- a/roll/pipeline/agentic/agentic_pipeline.py +++ b/roll/pipeline/agentic/agentic_pipeline.py @@ -207,7 +207,7 @@ def run(self): with Timer(name="cal_old_log_probs_values", logger=None) as cal_old_logpb_timer: batch.meta_info["is_offload_states"] = False - if self.pipeline_config.enable_old_logprobs: + if self.pipeline_config.enable_old_logprobs_recompute: old_log_probs: DataProto = self.actor_train.compute_log_probs(batch, blocking=True) batch.batch["old_log_probs"] = old_log_probs.batch["log_probs"] avg_old_log_prob = masked_mean(batch.batch["old_log_probs"], batch.batch["response_mask"][:, 1:]) diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py index 10744a66e..d5c84c120 100644 --- a/roll/pipeline/base_worker.py +++ b/roll/pipeline/base_worker.py @@ -54,7 +54,7 @@ def initialize(self, pipeline_config): self.strategy.initialize(model_provider=default_diffusion_module_provider) else: self.strategy.initialize(model_provider=default_actor_model_provider) - + self.tokenizer = self.strategy.tokenizer if self.pipeline_config.resume_from_checkpoint: load_dir = download_model(self.pipeline_config.resume_from_checkpoint) @@ -221,7 +221,7 @@ def stop_server(self, data: DataProto = None): def compute_log_probs(self, data: DataProto): """ return DataProto.from_dict(tensors={'log_probs': output}) - """ + """ global_step = data.meta_info.get("global_step", 0) is_offload_states = data.meta_info.get("is_offload_states", True) metrics = {} @@ -261,7 +261,7 @@ def forward_func_log_probs(self, data: DataProto, output_tensor: torch.Tensor): def get_old_log_probs_with_cache(self, data: DataProto, log_probs: torch.Tensor) -> torch.Tensor: """ - Get old_log_probs with intra-step caching when enable_old_logprobs == False. + Get old_log_probs with intra-step caching when enable_old_logprobs_recompute == False. When caching is enabled, the first forward pass log_probs can be reused as old_log_probs since they are mathematically equivalent in on-policy settings. This method can be overridden by subclasses for custom caching behavior. 
@@ -274,8 +274,8 @@ def get_old_log_probs_with_cache(self, data: DataProto, log_probs: torch.Tensor) old_log_probs tensor (detached, no gradients) """ # Original computation path when caching is disabled - if self.pipeline_config.enable_old_logprobs or "sample_uuid" not in data.non_tensor_batch: - # When enable_old_logprobs=True, use the pre-computed old_log_probs from batch + if self.pipeline_config.enable_old_logprobs_recompute or "sample_uuid" not in data.non_tensor_batch: + # When enable_old_logprobs_recompute=True, use the pre-computed old_log_probs from batch return data.batch["old_log_probs"] sample_uuids = data.non_tensor_batch["sample_uuid"] @@ -289,13 +289,13 @@ def get_old_log_probs_with_cache(self, data: DataProto, log_probs: torch.Tensor) for sample_uuid in sample_uuids: cached_old_log_probs.append(self._logprobs_cache[sample_uuid]) - old_log_probs = torch.cat(cached_old_log_probs, dim=0) + old_log_probs = torch.cat(cached_old_log_probs, dim=0).to(current_platform.device_type) else: # Cache miss - use current log_probs as old_log_probs (mathematically equivalent in on-policy) old_log_probs = log_probs.detach() if self.pipeline_config.ppo_epochs > 1: for i, sample_uuid in enumerate(sample_uuids): - self._logprobs_cache[sample_uuid] = old_log_probs[i:i+1] + self._logprobs_cache[sample_uuid] = old_log_probs[i : i + 1].cpu() return old_log_probs @@ -318,7 +318,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): ratio = (log_probs - old_log_probs).exp() pg_clip_low = self.pipeline_config.pg_clip_low if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip - pg_clip_high = self.pipeline_config.pg_clip_high if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip + pg_clip_high = self.pipeline_config.pg_clip_high if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip surr1 = ratio * advantages surr2 = ratio.clamp(1 - pg_clip_low, 1 + pg_clip_high) * advantages pg_loss = -torch.min(surr1, surr2) diff --git a/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py b/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py index f5078e19b..0bc2fc664 100644 --- a/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_math_vlm_pipeline.py @@ -390,7 +390,7 @@ def run(self): if self.pipeline_config.adv_estimator == "gae": values_refs: List[ray.ObjectRef] = self.critic.compute_values(batch, blocking=False) - if self.pipeline_config.enable_old_logprobs: + if self.pipeline_config.enable_old_logprobs_recompute: old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs( batch, blocking=False ) diff --git a/roll/pipeline/rlvr/rlvr_pipeline.py b/roll/pipeline/rlvr/rlvr_pipeline.py index bd5b469cd..77ba2fabd 100644 --- a/roll/pipeline/rlvr/rlvr_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_pipeline.py @@ -538,6 +538,7 @@ def run(self): batch = generate_output batch.meta_info["global_step"] = global_step + batch.meta_info["_broadcast_non_tensor_batch"] = True batch.non_tensor_batch['sample_uuid'] = np.array([str(uuid.uuid4()) for _ in range(batch.batch.shape[0])], dtype=object) @@ -570,7 +571,7 @@ def run(self): if self.pipeline_config.adv_estimator == "gae": values_refs: List[ray.ObjectRef] = self.critic.compute_values(batch, blocking=False) - if self.pipeline_config.enable_old_logprobs: + if self.pipeline_config.enable_old_logprobs_recompute: if self.pipeline_config.actor_train.use_dynamic_batching_in_infer: batch, dynamic_batching_metrics = dynamic_batching_shard( batch, diff --git 
a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py index 052175c46..acafdd29a 100644 --- a/roll/pipeline/rlvr/rlvr_vlm_pipeline.py +++ b/roll/pipeline/rlvr/rlvr_vlm_pipeline.py @@ -569,7 +569,7 @@ def run(self): if self.pipeline_config.adv_estimator == "gae": values_refs: List[ray.ObjectRef] = self.critic.compute_values(batch, blocking=False) - if self.pipeline_config.enable_old_logprobs: + if self.pipeline_config.enable_old_logprobs_recompute: old_log_probs_refs: List[ray.ObjectRef] = self.actor_train.compute_log_probs(batch, blocking=False) old_log_probs = DataProto.materialize_concat(data_refs=old_log_probs_refs) agg_entropy = agg_loss( From 24374f10ac57173184eabe218a8d92093a528071 Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Thu, 4 Dec 2025 13:31:14 +0800 Subject: [PATCH 52/58] (fix): fix vllm get_metrics exception. --- roll/distributed/strategy/vllm_strategy.py | 37 ++++++++++++---------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/roll/distributed/strategy/vllm_strategy.py b/roll/distributed/strategy/vllm_strategy.py index 21c5f6651..663bdf804 100644 --- a/roll/distributed/strategy/vllm_strategy.py +++ b/roll/distributed/strategy/vllm_strategy.py @@ -450,23 +450,26 @@ def add_lora(self, peft_config): def _collect_metrics_snapshot(self): """Collect metrics snapshots periodically in a background thread.""" - while True: - raw_metrics = self.model.get_metrics() - snapshot = { - 'vllm/kv_cache_usage_perc_max': [], - 'vllm/num_requests_waiting_max': [], - 'vllm/num_preemptions_max': [] - } - for metric in raw_metrics: - if metric.name == "vllm:kv_cache_usage_perc": - snapshot['vllm/kv_cache_usage_perc_max'].append(metric.value) - elif metric.name == "vllm:num_requests_waiting": - snapshot['vllm/num_requests_waiting_max'].append(metric.value) - elif metric.name == "vllm:num_preemptions": - snapshot['vllm/num_preemptions_max'].append(metric.value) - self._metrics_snapshots.append(snapshot) - - time.sleep(self._metrics_snapshot_interval) + try: + while True: + raw_metrics = self.model.get_metrics() + snapshot = { + 'vllm/kv_cache_usage_perc_max': [], + 'vllm/num_requests_waiting_max': [], + 'vllm/num_preemptions_max': [] + } + for metric in raw_metrics: + if metric.name == "vllm:kv_cache_usage_perc": + snapshot['vllm/kv_cache_usage_perc_max'].append(metric.value) + elif metric.name == "vllm:num_requests_waiting": + snapshot['vllm/num_requests_waiting_max'].append(metric.value) + elif metric.name == "vllm:num_preemptions": + snapshot['vllm/num_preemptions_max'].append(metric.value) + self._metrics_snapshots.append(snapshot) + + time.sleep(self._metrics_snapshot_interval) + except Exception as e: + logger.warning(f"Failed to get metrics: {e}") def get_metrics(self, metric_names: Optional[List[str]] = None) -> Dict[str, float]: """ From 61a544a48c74a5623c00ff4bf2d2fc76cc66226c Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Thu, 4 Dec 2025 13:37:00 +0800 Subject: [PATCH 53/58] (fix): fix vllm 0110. 
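A note on the `get_metrics` fix in the previous patch: the vLLM strategy polls engine metrics from a background daemon thread, and an unhandled exception used to terminate that loop silently. The sketch below shows the guarded polling pattern in isolation, with a fake `get_metrics` callable and `SimpleNamespace` objects standing in for vLLM's metric records; it is an illustration of the pattern, not the strategy's real code.

```python
import threading
import time
from collections import deque
from types import SimpleNamespace

snapshots = deque(maxlen=64)  # bounded history of metric snapshots

def collect_metrics_forever(get_metrics, interval=0.1, warn=print):
    """Poll metrics in a loop; wrap the whole loop so a failure is logged
    instead of silently terminating the daemon thread."""
    try:
        while True:
            snapshot = {"kv_cache_usage_perc_max": [], "num_requests_waiting_max": []}
            for metric in get_metrics():
                if metric.name == "kv_cache_usage_perc":
                    snapshot["kv_cache_usage_perc_max"].append(metric.value)
                elif metric.name == "num_requests_waiting":
                    snapshot["num_requests_waiting_max"].append(metric.value)
            snapshots.append(snapshot)
            time.sleep(interval)
    except Exception as exc:
        warn(f"Failed to get metrics: {exc}")

def fake_metrics():
    return [SimpleNamespace(name="kv_cache_usage_perc", value=0.42)]

thread = threading.Thread(target=collect_metrics_forever, args=(fake_metrics,), daemon=True)
thread.start()
```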
--- roll/third_party/vllm/vllm_0_11_0/llm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/roll/third_party/vllm/vllm_0_11_0/llm.py b/roll/third_party/vllm/vllm_0_11_0/llm.py index b5b3aa48d..e9a117880 100644 --- a/roll/third_party/vllm/vllm_0_11_0/llm.py +++ b/roll/third_party/vllm/vllm_0_11_0/llm.py @@ -198,8 +198,9 @@ def __init__( # Load the Input/Output processor plugin if any self.model_config = self.llm_engine.model_config - self.processor = self.llm_engine.processor - self.io_processor = self.llm_engine.io_processor + io_processor_plugin = self.llm_engine.model_config.io_processor_plugin + self.io_processor = get_io_processor(self.llm_engine.vllm_config, + io_processor_plugin) def load_states(self): self.collective_rpc(method="load_states") From e1695f2dce9531f982196412a3ea3fd71d393ef1 Mon Sep 17 00:00:00 2001 From: "xiongshaopan.xsp" Date: Thu, 4 Dec 2025 15:34:06 +0800 Subject: [PATCH 54/58] (fix): fix AgenticAcotrWorker import. --- roll/pipeline/agentic/agentic_actor_worker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/roll/pipeline/agentic/agentic_actor_worker.py b/roll/pipeline/agentic/agentic_actor_worker.py index a7f683e92..75510c675 100644 --- a/roll/pipeline/agentic/agentic_actor_worker.py +++ b/roll/pipeline/agentic/agentic_actor_worker.py @@ -4,7 +4,6 @@ from roll.distributed.scheduler.protocol import DataProto from roll.pipeline.base_worker import ActorWorker as BaseActorWorker from roll.utils.functionals import masked_mean, agg_loss, compute_approx_kl -from roll.pipeline.agentic.utils import compute_segment_masked_mean class ActorWorker(BaseActorWorker): From a595ec3d5f09091b3b8c6f2196d056d3e2264f77 Mon Sep 17 00:00:00 2001 From: millioniron Date: Wed, 3 Dec 2025 14:08:18 +0800 Subject: [PATCH 55/58] new file: docs_roll/docs/User Guides/Configuration/infer_correction.md new file: examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml new file: examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml new file: examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh new file: examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh modified: roll/configs/base_config.py modified: roll/configs/generating_args.py modified: roll/distributed/scheduler/generate_scheduler.py modified: roll/pipeline/agentic/env_manager/step_env_manager.py modified: roll/pipeline/agentic/env_manager/traj_env_manager.py modified: roll/pipeline/base_worker.py modified: roll/pipeline/rlvr/actor_pg_worker.py modified: roll/pipeline/rlvr/actor_worker.py --- .../Configuration/infer_correction.md | 142 ++++++++++ .../agentic_webshop_infer_correction.yaml | 183 ++++++++++++ .../rlvr_infer_correction_config.yaml | 265 ++++++++++++++++++ .../run_agentic_pipeline_webshop.sh | 10 + .../run_rlvr_pipeline.sh | 5 + roll/configs/base_config.py | 56 ++++ roll/configs/generating_args.py | 4 + .../agentic/env_manager/step_env_manager.py | 1 + .../agentic/env_manager/traj_env_manager.py | 3 +- roll/pipeline/base_worker.py | 165 ++++++++++- roll/pipeline/rlvr/actor_pg_worker.py | 10 +- roll/pipeline/rlvr/actor_worker.py | 58 +--- 12 files changed, 846 insertions(+), 56 deletions(-) create mode 100644 docs_roll/docs/User Guides/Configuration/infer_correction.md create mode 100644 examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml create mode 100644 examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml create mode 100644 examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh create mode 100644 
examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh diff --git a/docs_roll/docs/User Guides/Configuration/infer_correction.md b/docs_roll/docs/User Guides/Configuration/infer_correction.md new file mode 100644 index 000000000..ecdce6325 --- /dev/null +++ b/docs_roll/docs/User Guides/Configuration/infer_correction.md @@ -0,0 +1,142 @@ +# 训推差异修复 + + +## 简介 + +训推差异是由于RL训练过程中,训练器和生成器之间由于后端不同(vLLM vs SGLang vs FSDP vs Megatron),精度不同(FP8 vs FP16 vs BF16 vs FP32),形成了一种类似off-policy gap,会导致训练不稳定和策略崩溃。 + + +## 实现原理 + + +修复训推差异大致可分为两种方法(1)对训练器和生成器进行策略修正(2)使用infer_log_probs直接代替old_log_probs(trainer)进行PPO ratio计算。第二种方案比较直接,我们着重说明第一种方法。 + +### 对训练器和生成器进行策略修正 + +### IS权重矫正 +通过对训练器(old_log_probs)和生成器(infer_log_prob)之间进行重要性采样矫正,弥合训推差异。与off-policy算法类似,IS权重矫正可以区分token级别和sequence级别,只能选择一个。 + +### MASK过滤掉不符合条件的样本 +与IS权重修正不同的是,此方法对于超过阈值的样本直接进行mask遮掩,过滤掉不符合的样本。涉及的方法有(1)token级别:过滤掉不符合条件的token(2)灾难性token:过滤掉出现灾难性严重偏差的token的句子样本(3)sequence级别:对sequence进行IS计算,过滤掉不符合的句子样本(4)sequence级别,使用几何平均来计算IS权重,但指标也更为敏感 + + +## 关键参数配置 + +生成器是否返回infer_log_probs + +GeneratingArguments: + +```yaml +actor_infer: + model_args: + disable_gradient_checkpointing: true + dtype: fp16 + generating_args: + max_new_tokens: ${response_length} + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: ${num_return_sequences_in_group} + logprobs: 1 +``` +当logprobs:大于0时,返回infer_log_probs + + +------- + +参数的配置在PPOConfig和中,关键配置参数如下: + +```yaml +infer_correction: true + +infer_is_mode: token # 可选 token sequence +infer_is_threshold_min: 0.0 +infer_is_threshold_max: 2.0 # 1.5~5.0 + +enable_token_reject: true +infer_token_mask_threshold_min: 0.0 +infer_token_mask_threshold_max: 2.0 # 2~10 + +enable_catastrophic_reject: true +infer_catastrophic_threshold: 1e-4 + +enable_seq_reject: sequence 可选None sequence geometric +infer_seq_mask_threshold_min: 0.1 +infer_seq_mask_threshold_max: 10 +``` + + +### infer_correction +- **含义**:控制是否启用训推差异修复机制。若启用,系统将使用 `infer_log_probs` 对策略梯度进行矫正。 +- **默认值**:`false` + +### infer_is_mode +- **含义**:指定重要性采样(IS)权重的计算粒度。 +- **可选值**: + - `"token"`:每个 token 独立计算 IS 权重 + - `"sequence"`:基于整个 response 序列的 log-ratio 求和后广播至所有 token + - `"none"`(或未设置):不应用 IS 加权,权重恒为 1 +- **默认值**:若未设置,默认为 `"token"` +- **注意**:不可同时使用多种模式,仅能选择其一。 + +### infer_is_threshold_min +- **含义**:IS 权重的下限阈值,用于裁剪过小的权重以控制方差。 +- **默认值**:`0.0` +- **建议**:通常保留为 `0.0`,以保持无偏性下界 + +### infer_is_threshold_max +- **含义**:IS 权重的上限阈值,防止极端大的权重主导梯度。 +- **默认值**:`2.0` +- **建议**:`"token"`级别推荐为 `1.5 ~ 5.0` `"sequence"`级别推荐为2.0 - 10.0 + +### enable_token_reject +- **含义**:是否启用 token 级别的样本拒绝机制。 +- **默认值**:`false` +- **作用**:结合 `infer_token_mask_threshold_min/max`,mask 掉 IS ratio 超出合法区间的 token。 + +### infer_token_mask_threshold_min +- **含义**:token 级 IS ratio(`old_log_probs / infer_log_probs` 的指数)的下限。 +- **默认值**:`0.0` +- **典型值**:`0.0`通常可设为1/max + +### infer_token_mask_threshold_max +- **含义**:遮掩token 级 IS ratio 的上限。 +- **默认值**:`2.0` +- **典型范围**:`1.5 ~ 5.0` + +### enable_catastrophic_reject +- **含义**:是否启用“灾难性偏差”检测并拒绝整句样本。 +- **默认值**:`false` +- **触发条件**:只要序列中存在任一 token 满足 `ratio < infer_catastrophic_threshold`,则整句被 mask。 + +### infer_catastrophic_threshold +- **含义**:灾难性拒绝的 ratio 下限阈值。 +- **默认值**:`1e-4` +- **解释**:当 `infer_log_probs` 远大于 `old_log_probs`(即生成器过于“自信”),导致 `ratio = exp(old - infer)` 极小 + +### enable_seq_reject +- **含义**:是否启用序列级别的拒绝机制,以及使用何种聚合方式。 +- **可选值**: + - `null` / `false`:禁用 + - `"sequence"`:使用 log-ratio **求和** 计算序列 IS ratio + - `"geometric"`:使用 log-ratio **平均**(等价于几何平均概率)计算序列 IS ratio +- **默认值**:`null` + +### infer_seq_mask_threshold_min +- 
**含义**:遮掩序列级 IS ratio 的下限。 +- **默认值**:`0.1` +- **典型值**:通常可设为1/max,当使用`"geometric"`时,最好强制设为1/max + + +### infer_seq_mask_threshold_max +- **含义**:遮掩序列级 IS ratio 的上限。 +- **默认值**:`10.0` +- **典型范围**:当使用`"sequence"`时,推荐`2.0 ~ 10.0`,但随着长度增加可适当放宽。当使用`"geometric"`时,推荐设置为1.0001 - 1.001 + + + +## 使用建议 + +1. 通常情况下,old_log_prob << infer_log_porb, 阈值的下限就比较重要了。并不建议使用sequence级别的IS或MASK + diff --git a/examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml b/examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml new file mode 100644 index 000000000..78fefee47 --- /dev/null +++ b/examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml @@ -0,0 +1,183 @@ +defaults: + - ../config/traj_envs@_here_ + - ../config/deepspeed_zero@_here_ + - ../config/deepspeed_zero2@_here_ + - ../config/deepspeed_zero3@_here_ + - ../config/deepspeed_zero3_cpuoffload@_here_ + +hydra: + run: + dir: . + output_subdir: null + +exp_name: "agentic_pipeline_webshop_infer_correction" +seed: 42 +logging_dir: ./output/logs +output_dir: ./output +render_save_dir: ./output/render +system_envs: + USE_MODELSCOPE: '1' + +#track_with: wandb +#tracker_kwargs: +# api_key: +# project: roll-agentic +# name: ${exp_name}_webshop +# notes: "agentic_pipeline" +# tags: +# - agentic +# - roll +# - baseline + +#track_with: swanlab +#tracker_kwargs: +# login_kwargs: +# api_key: your_api_key +# project: roll-agentic +# logdir: debug +# experiment_name: ${exp_name} +# tags: +# - roll +# - agentic +# - debug + +track_with: tensorboard +tracker_kwargs: + log_dir: /data/oss_bucket_0/yali/llm/tensorboard/roll_exp/agentic_webshop + +num_gpus_per_node: 8 + +max_steps: 1024 +save_steps: 10000 +logging_steps: 1 +eval_steps: 10 +resume_from_checkpoint: false + +rollout_batch_size: 64 +val_batch_size: 64 +sequence_length: 8192 + +reward_clip: 20 +advantage_clip: 0.2 # 0.1-0.3 +ppo_epochs: 1 +adv_estimator: "grpo" +#pg_clip: 0.1 +max_grad_norm: 1.0 +#dual_clip_loss: True +init_kl_coef: 0.0 +whiten_advantages: true +entropy_loss_coef: 0 + +pretrain: Qwen/Qwen2.5-7B-Instruct +reward_pretrain: Qwen/Qwen2.5-7B-Instruct + +# infer correction + +infer_correction: true + +infer_is_mode: token #token sequence +infer_is_threshold_min: 0.0 +infer_is_threshold_max: 2.0 # 1.5~5.0 + +enable_token_reject: false +infer_token_mask_threshold_min: 0.0 +infer_token_mask_threshold_max: 2.0 # 2~10 + +enable_catastrophic_reject: false +infer_catastrophic_threshold: 1e-4 + + +enable_seq_reject: None # None sequence geometric +infer_seq_mask_threshold_min: 0.999 +infer_seq_mask_threshold_max: 1.001 + +# enable_seq_reject: geometric +# infer_seq_mask_threshold_min: 0.999 +# infer_seq_mask_threshold_max: 1.001 + + +actor_train: + model_args: + attn_implementation: fa2 + disable_gradient_checkpointing: false + dtype: bf16 + model_type: ~ + training_args: + learning_rate: 1.0e-6 + weight_decay: 0 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 8 + warmup_steps: 10 + data_args: + template: qwen2_5 + strategy_args: + strategy_name: megatron_train + strategy_config: + tensor_model_parallel_size: 1 + context_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + use_distributed_optimizer: true + recompute_granularity: full + max_grad_norm: ${max_grad_norm} + device_mapping: list(range(0,8)) + infer_batch_size: 1 + +actor_infer: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + generating_args: + max_new_tokens: 1024 # single-turn response length + top_p: 0.99 + top_k: 100 + num_beams: 1 + 
temperature: 0.99 + num_return_sequences: 1 + logprobs: 1 + data_args: + template: qwen2_5 + strategy_args: + strategy_name: sglang + strategy_config: + mem_fraction_static: 0.85 + load_format: auto + device_mapping: list(range(0,8)) + infer_batch_size: 1 + +reference: + model_args: + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: ~ + data_args: + template: qwen2_5 + strategy_args: + strategy_name: hf_infer + strategy_config: ~ + device_mapping: list(range(0,8)) + infer_batch_size: 1 + +reward_normalization: + grouping: traj_group_id # 可以tags(env_type)/traj_group_id(group)/batch(rollout_batch)... group_by计算reward/adv + method: mean_std # asym_clip / identity / mean_std + +train_env_manager: + format_penalty: -0.05 + num_env_groups: 8 + group_size: 8 + max_env_num_per_worker: 1 # The max_env_num_per_worker must be set to 1 to avoid conflicts with the webshop simple server. + tags: [WebShopEnv] + num_groups_partition: [8] # If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation + +val_env_manager: + num_env_groups: 64 + group_size: 1 # should be set to 1 because val temperature is set to 0 and same prompt leads to same output + max_env_num_per_worker: 1 # The max_env_num_per_worker must be set to 1 to avoid conflicts with the webshop simple server. + tags: [WebShopEnv] + num_groups_partition: [64] # TODO: If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation + +custom_envs: + WebShopEnv: + ${custom_env.WebShopEnv} \ No newline at end of file diff --git a/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml new file mode 100644 index 000000000..e2cbac8a7 --- /dev/null +++ b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml @@ -0,0 +1,265 @@ +hydra: + run: + dir: . 
+ output_subdir: null + +exp_name: "qwen2.5-7B-rlvr-infer-correction-config" +seed: 42 +logging_dir: ./output/logs +output_dir: ./output +system_envs: + USE_MODELSCOPE: '1' + + +checkpoint_config: + type: file_system + output_dir: /data/cpfs_0/rl_examples/models/${exp_name} + +track_with: tensorboard +tracker_kwargs: + log_dir: ./rl_examples/llm/tensorboard/roll_exp/rlvr + +num_gpus_per_node: 8 + +max_steps: 500 +save_steps: 100 +logging_steps: 1 +eval_steps: 10 +resume_from_checkpoint: false + + +rollout_batch_size: 64 # prompt +prompt_length: 2048 +response_length: 4096 + +num_return_sequences_in_group: 8 +ppo_epochs: 1 +adv_estimator: "reinforce" + +# clip +value_clip: 0.5 +reward_clip: 10 +advantage_clip: 2.0 +dual_clip_loss: true + +# normalize +norm_mean_type: ~ +norm_std_type: ~ + +# data mask +max_len_mask: true +difficulty_mask: true +difficulty_low_threshold: 0.1 +difficulty_high_threshold: 0.95 +error_max_len_clip: false + +# data weight +difficulty_loss_weight: false +length_loss_weight: false + +# reward +add_token_level_kl: false + +# advantage +whiten_advantages: true + +# dynamic sampling scheduler +# use_additional_prompts: true +# max_running_requests: 256 +# is_num_return_sequences_expand: false + +pretrain: Qwen/Qwen2.5-7B-Instruct +reward_pretrain: Qwen/Qwen2.5-7B-Instruct + + +# infer correction +infer_correction: true + +infer_is_mode: token +infer_is_threshold_min: 0.0 +infer_is_threshold_max: 2.0 # 1.5~5.0 + +enable_token_reject: false +infer_token_rs_threshold_min: 0.0 +infer_token_rs_threshold_max: 2.0 # 2~10 + +enable_catastrophic_reject: false +infer_catastrophic_threshold: 1e-4 + +enable_seq_reject: None + +# enable_seq_reject: sequence +# infer_seq_rs_threshold_min: 0.1 +# infer_seq_rs_threshold_max: 10 + +# enable_seq_reject: geometric +# infer_seq_rs_threshold_min: 0.999 +# infer_seq_rs_threshold_max: 1.001 + +validation: + data_args: + template: qwen2.5 + file_name: + - data/math_benchmarks.jsonl + generating_args: + max_new_tokens: ${response_length} + top_p: 0.6 + top_k: 50 + num_beams: 1 + temperature: 0.6 + num_return_sequences: 1 + + +actor_train: + model_args: + disable_gradient_checkpointing: false + dtype: bf16 + model_type: ~ + training_args: + learning_rate: 1.0e-6 + weight_decay: 0 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 32 + warmup_steps: 20 + num_train_epochs: 50 + data_args: + template: qwen2.5 + file_name: + - data/code_KodCode_data.jsonl + - data/llm_judge_Multi-subject-RLVR_deal_new.jsonl + - data/math_deepmath_deal.jsonl + - data/general_ifeval_train_deal.jsonl + - data/general_CrossThink-QA_deal.jsonl + domain_interleave_probs: + math_rule: 0.4 + code_sandbox: 0.3 + llm_judge: 0.1 + crossthinkqa: 0.1 + ifeval: 0.1 + dataset_dir: data + messages: messages + interleave_probs: "1.0" + preprocessing_num_workers: 16 + strategy_args: + strategy_name: megatron_train + strategy_config: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + use_distributed_optimizer: true + recompute_granularity: full + device_mapping: list(range(0,16)) + infer_batch_size: 4 + +actor_infer: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + generating_args: + max_new_tokens: ${response_length} + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: ${num_return_sequences_in_group} + logprobs: 1 + data_args: + template: qwen2.5 + strategy_args: + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.8 + block_size: 16 + max_model_len: 
8000 + device_mapping: list(range(0,12)) + infer_batch_size: 1 + +reference: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + model_type: ~ + data_args: + template: qwen2.5 + strategy_args: + strategy_name: megatron_infer + strategy_config: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + device_mapping: list(range(0,16)) + infer_batch_size: 4 + +rewards: + crossthinkqa: + worker_cls: roll.pipeline.rlvr.rewards.crossthinkqa_rule_reward_worker.CrossThinkQARuleRewardWorker + reward_type: soft + response_length_penalty_coef: 0.0 + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + tag_included: [crossthinkqa] + world_size: 8 + infer_batch_size: 4 + ifeval: + worker_cls: roll.pipeline.rlvr.rewards.ifeval_rule_reward_worker.GeneralRuleRewardWorker + reward_type: soft + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + tag_included: [ifeval] + world_size: 8 + infer_batch_size: 4 + math_rule: + worker_cls: roll.pipeline.rlvr.rewards.math_rule_reward_worker.MathRuleRewardWorker + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + tag_included: [deepmath_103k, aime] + world_size: 8 + infer_batch_size: 1 + code_sandbox: + use_local: true + worker_cls: roll.pipeline.rlvr.rewards.code_sandbox_reward_worker.CodeSandboxRewardWorker + tag_included: [KodCode] + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + world_size: 8 + infer_batch_size: 1 + llm_judge: + # NOTE: llm as judge 也需要gpu, 不能和actor infer共享gpu + worker_cls: roll.pipeline.rlvr.rewards.llm_judge_reward_worker.LLMJudgeRewardWorker + judge_prompt: Qwen2.5-7B-Instruct-RLVR-prompt + judge_model_type: inference + tag_included: [RLVR] + model_args: + model_name_or_path: virtuoussy/Qwen2.5-7B-Instruct-RLVR + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: trl + generating_args: + max_new_tokens: 100 + top_p: 0.8 + top_k: 50 + num_beams: 1 + temperature: 0.8 + num_return_sequences: 1 + data_args: + template: qwen2.5 + strategy_args: + # strategy_name: hf_infer + # strategy_config: null + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.8 + block_size: 16 + max_model_len: 8000 + load_format: auto + device_mapping: list(range(12,16)) + infer_batch_size: 4 \ No newline at end of file diff --git a/examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh b/examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh new file mode 100644 index 000000000..ff1a59402 --- /dev/null +++ b/examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh @@ -0,0 +1,10 @@ +#!/bin/bash +# Run `git submodule update --init --recursive` to init submodules before run this script. 
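# Note: CONFIG_PATH below is just this script's directory name and
# examples/start_agentic_pipeline.py is referenced by a relative path, so the
# script is presumably meant to be launched from the repository root, e.g.
#   bash examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh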
+set +x + + +pip install -r third_party/webshop-minimal/requirements.txt --trusted-host mirrors.aliyun.com --index-url https://mirrors.aliyun.com/pypi/simple/ +python -m spacy download en_core_web_sm + +CONFIG_PATH=$(basename $(dirname $0)) +python examples/start_agentic_pipeline.py --config_path $CONFIG_PATH --config_name agentic_webshop_infer_correction diff --git a/examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh b/examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh new file mode 100644 index 000000000..6874f4f5e --- /dev/null +++ b/examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh @@ -0,0 +1,5 @@ +#!/bin/bash +set +x + +CONFIG_PATH=$(basename $(dirname $0)) +python examples/start_rlvr_pipeline.py --config_path $CONFIG_PATH --config_name rlvr_infer_correction_config \ No newline at end of file diff --git a/roll/configs/base_config.py b/roll/configs/base_config.py index aa1da4c27..8dc75d3f3 100644 --- a/roll/configs/base_config.py +++ b/roll/configs/base_config.py @@ -395,12 +395,68 @@ class PPOConfig(BaseConfig): field(default="seq-mean-token-mean", metadata={"help": "Loss aggregation mode"}) ) dual_clip_loss: bool = field(default=False, metadata={"help": "Use dual clip loss"}) + enable_reference: bool = field( default=False, metadata={"help": "Whether to enable reference cluster for computing ref_log_probs."} ) enable_old_logprobs_recompute: bool = field(default=False, metadata={"help": "Enable old_logprobs computation optimization for disable caching"}) force_disable_old_logprobs_recompute: bool = field(default=False, metadata={"help": "Force disable old_logprobs computation optimization for disable caching, priority is higher than enable_old_logprobs_recompute"}) + # trainer&rollout mismatch + infer_correction: bool = field( + default=False, + metadata={"help": "Whether to apply importance sampling correction during inference."} + ) + infer_is_mode: Literal["token", "sequence", "none"] = field( + default="token", + metadata={"help": "IS weighting mode: 'token' (per-token ratio), 'sequence' (per-sequence ratio), 'none' (no IS weighting)."} + ) + # Clipping thresholds (used in IS weighting) + infer_is_threshold_min: float = field( + default=0.0, + metadata={"help": "Minimum threshold for IS weight clipping. 
Recommended 0.0 for unbiased estimation."} + ) + infer_is_threshold_max: float = field( + default=2.0, + metadata={"help": "Maximum threshold for IS weight clipping."} + ) + # Token-level rejection + enable_token_reject: bool = field( + default=False, + metadata={"help": "Enable token-level rejection based on IS ratio thresholds."} + ) + infer_token_mask_threshold_min: float = field( + default=0.0, + metadata={"help": "Minimum IS ratio threshold for token rejection."} + ) + infer_token_mask_threshold_max: float = field( + default=2.0, + metadata={"help": "Maximum IS ratio threshold for token rejection."} + ) + # Catastrophic rejection (reject entire sequence if any token ratio is too small) + enable_catastrophic_reject: bool = field( + default=False, + metadata={"help": "Enable catastrophic rejection: reject entire sequence if any valid token has IS ratio below threshold."} + ) + infer_catastrophic_threshold: float = field( + default=1e-4, + metadata={"help": "Threshold below which a token triggers catastrophic rejection of its sequence."} + ) + # Sequence-level rejection + enable_seq_reject: Optional[Literal["sequence", "geometric",'None']] = field( + default=None, + metadata={"help": "Enable sequence-level rejection: 'sequence' uses sum of log-ratios, 'geometric' uses mean. None disables."} + ) + infer_seq_mask_threshold_min: float = field( + default=0.1, + metadata={"help": "Minimum IS ratio threshold for sequence rejection."} + ) + infer_seq_mask_threshold_max: float = field( + default=10.0, + metadata={"help": "Maximum IS ratio threshold for sequence rejection (typically larger than token-level)."} + ) + + def __post_init__(self): super().__post_init__() diff --git a/roll/configs/generating_args.py b/roll/configs/generating_args.py index 68cf88d17..0822614ba 100644 --- a/roll/configs/generating_args.py +++ b/roll/configs/generating_args.py @@ -58,6 +58,10 @@ class GeneratingArguments: default=None, metadata={"help": "Whether to include the stop strings in output text."}, ) + logprobs: Optional[int] = field( + default=None, + metadata={"help": "Whether return infer log-prob."}, + ) def to_dict(self) -> Dict[str, Any]: args = asdict(self) diff --git a/roll/pipeline/agentic/env_manager/step_env_manager.py b/roll/pipeline/agentic/env_manager/step_env_manager.py index 4348605a3..26e95106d 100644 --- a/roll/pipeline/agentic/env_manager/step_env_manager.py +++ b/roll/pipeline/agentic/env_manager/step_env_manager.py @@ -105,6 +105,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): "response_mask": response_mask, "prompt_mask": prompt_mask, "scores": score_tensor, + "infer_logprobs": infer_logprobs, }, batch_size=input_ids.shape[0]), non_tensor_batch={ diff --git a/roll/pipeline/agentic/env_manager/traj_env_manager.py b/roll/pipeline/agentic/env_manager/traj_env_manager.py index 88ab15d91..80f15113f 100644 --- a/roll/pipeline/agentic/env_manager/traj_env_manager.py +++ b/roll/pipeline/agentic/env_manager/traj_env_manager.py @@ -11,7 +11,7 @@ from omegaconf import DictConfig from tensordict import TensorDict from transformers import PreTrainedTokenizer - +import json from roll.pipeline.agentic.llm_proxy import create_llm_proxy, BaseLLMProxy from roll.pipeline.agentic.env_manager.base_env_manager import RolloutCache, BaseEnvManager from roll.utils.env_action_limiter import get_global_limiter @@ -308,6 +308,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): infer_logprobs.extend([0] * len(items["prompt_ids"]) + items["infer_logprobs"]) input_ids =torch.tensor(token_ids, 
dtype=torch.long).unsqueeze(0) + infer_logprobs = torch.tensor(infer_logprobs, dtype=torch.float).unsqueeze(0) attention_mask = torch.tensor([1] * len(token_ids), dtype=torch.long).unsqueeze(0) response_mask = torch.tensor(response_masks, dtype=torch.bool).unsqueeze(0) diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py index d5c84c120..7148752a1 100644 --- a/roll/pipeline/base_worker.py +++ b/roll/pipeline/base_worker.py @@ -2,7 +2,7 @@ import threading import time from typing import Union, Optional, Dict - +import numpy import ray import torch from codetiming import Timer @@ -25,6 +25,7 @@ postprocess_generate, GenerateRequestType, agg_loss, + masked_sum ) from roll.utils.offload_nccl import reload_process_groups from roll.utils.offload_states import OffloadStateType @@ -307,8 +308,10 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ response_mask = data.batch["response_mask"][:, 1:].long() + final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] advantages = data.batch["advantages"] + infer_log_probs = data.batch.get("infer_logprobs", None) log_probs = self.strategy.op_compute_log_probs( logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"] @@ -325,12 +328,18 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-pg_loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) pg_loss = torch.where(advantages < 0, dual_clip_loss, pg_loss) + + if infer_log_probs is not None and self.pipeline_config.infer_correction: + pg_loss, infer_response_mask, infer_stats=self.infer_correction( + old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, + response_mask=response_mask,pg_loss=pg_loss) + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() + + pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=final_response_mask, - pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) - kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=response_mask, - kl_penalty="k3") - kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) approxkl = compute_approx_kl( log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="mse" @@ -372,9 +381,153 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): "actor/policykl": agg_loss(loss_mat=policykl, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(), } + pg_metrics.update(infer_stats) return total_loss, pg_metrics - + + + def infer_correction( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor, + pg_loss: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, dict]: + """ + 处理 importance sampling ratio,支持 IS 裁剪与多种 reject 策略。 + 返回更新后的 pg_loss、mask 和详细统计信息。 + """ + # Step 0: Shape alignment + if infer_log_probs.shape[1] == old_log_probs.shape[1]+1: + infer_log_probs = infer_log_probs[:, 1:] # align with response_mask[:, 1:] + assert 
old_log_probs.shape == infer_log_probs.shape == response_mask.shape, \ + f"Shape mismatch: {old_log_probs.shape}, {infer_log_probs.shape}, {response_mask.shape}" + # Step 1: Compute log-ratio and ratio + log_ratio = old_log_probs - infer_log_probs # [B, T] + ratio = torch.exp(log_ratio) # [B, T] + # Step 2: Apply IS weighting strategy (optional) + if self.pipeline_config.infer_is_mode == "token": + raw_is_weight = ratio + elif self.pipeline_config.infer_is_mode == "sequence": + log_ratio_sum = masked_sum(log_ratio, response_mask, dim=-1).unsqueeze(-1) # [B, 1] + raw_is_weight = torch.exp(log_ratio_sum).expand_as(ratio) # [B, T] + elif self.pipeline_config.infer_is_mode in (None, "none", ""): + raw_is_weight = torch.ones_like(ratio) + else: + raw_is_weight = torch.ones_like(ratio) + # Clamp to get final is_weight (used for loss) + is_weight = raw_is_weight.clamp( + min=self.pipeline_config.infer_is_threshold_min, + max=self.pipeline_config.infer_is_threshold_max + ).detach() + # Step 3: Build rejection mask + original_valid = response_mask > 0.5 # [B, T], bool + keep_mask = original_valid.clone() + # (a) Token-level ratio reject + if getattr(self.pipeline_config, 'enable_token_reject', False): + ratio_too_high = ratio > self.pipeline_config.infer_token_mask_threshold_max + ratio_too_low = ratio < self.pipeline_config.infer_token_mask_threshold_min + token_reject = ratio_too_high | ratio_too_low + keep_mask = keep_mask & (~token_reject) + # (b) Catastrophic reject + if getattr(self.pipeline_config, 'enable_catastrophic_reject', False): + catastrophic = (ratio < self.pipeline_config.infer_catastrophic_threshold) & original_valid + has_catastrophic = catastrophic.any(dim=-1, keepdim=True) + keep_mask = keep_mask & (~has_catastrophic) + # (c) Sequence-level reject + if getattr(self.pipeline_config, 'enable_seq_reject', False): + if self.pipeline_config.enable_seq_reject=="sequence": + log_ratio_sum = masked_sum(log_ratio, response_mask, dim=-1) # [B] + seq_ratio = torch.exp(log_ratio_sum) # [B] + seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max + seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min + seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) + keep_mask = keep_mask & (~seq_reject) + elif self.pipeline_config.enable_seq_reject=="geometric": + log_ratio_mean = masked_mean(log_ratio, response_mask, dim=-1) # [B] + seq_ratio = torch.exp(log_ratio_mean) # [B] + seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max + seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min + seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) + keep_mask = keep_mask & (~seq_reject) + # final_mask = keep_mask.float() + final_mask = keep_mask + # Step 4: Reweight policy loss + pg_loss = pg_loss * is_weight + # Step 5: Compute detailed stats over original_valid tokens + # Rejected mask + rejected_mask = original_valid & (~keep_mask) # [B, T] + # Clipped mask: only meaningful if IS weighting is active + if self.pipeline_config.infer_is_mode in ("token", "sequence"): + clipped_low = (raw_is_weight <= self.pipeline_config.infer_is_threshold_min) & original_valid + clipped_high = (raw_is_weight >= self.pipeline_config.infer_is_threshold_max) & original_valid + clipped_mask = clipped_low | clipped_high # [B, T] + else: + clipped_mask = torch.zeros_like(original_valid) # no clipping + # Compute fractions + def _compute_frac(mask_tensor): + return agg_loss( + loss_mat=mask_tensor.float(), + loss_mask=response_mask, + 
loss_agg_mode="token-mean" # force token-wise average + ).detach().item() + clip_frac = _compute_frac(clipped_mask) + reject_frac = _compute_frac(rejected_mask) + clip_and_reject_frac = _compute_frac(clipped_mask & rejected_mask) + clip_or_reject_frac = _compute_frac(clipped_mask | rejected_mask) + # A sequence is rejected if NO token is kept (i.e., all final_mask == 0 for that seq) + seq_has_valid = original_valid.any(dim=-1) # [B], bool: seq has >=1 valid token + seq_completely_rejected = (~keep_mask).all(dim=-1) & seq_has_valid # [B] + total_valid_seqs = seq_has_valid.sum().item() + rejected_seqs = seq_completely_rejected.sum().item() + seq_reject_frac = rejected_seqs / total_valid_seqs if total_valid_seqs > 0 else 0.0 + + ### kl metric + inferkl_orig = compute_approx_kl( + log_probs=infer_log_probs, + log_probs_base=old_log_probs, + action_mask=response_mask, # ← original mask + kl_penalty="kl" + ) + inferkl_final = compute_approx_kl( + log_probs=infer_log_probs, + log_probs_base=old_log_probs, + action_mask=final_mask, # ← after rejection + kl_penalty="kl" + ) + inferkl_orig_agg = agg_loss( + loss_mat=inferkl_orig, + loss_mask=response_mask, + loss_agg_mode=self.pipeline_config.loss_agg_mode + ).detach().item() + inferkl_final_agg = agg_loss( + loss_mat=inferkl_final, + loss_mask=final_mask, + loss_agg_mode=self.pipeline_config.loss_agg_mode + ).detach().item() + valid_raw_is_weight = raw_is_weight[original_valid] # [N_valid_tokens,] + if valid_raw_is_weight.numel() > 0: + raw_is_mean = valid_raw_is_weight.mean().detach().item() + raw_is_std = valid_raw_is_weight.std(unbiased=False).detach().item() + raw_is_min = valid_raw_is_weight.min().detach().item() + raw_is_max = valid_raw_is_weight.max().detach().item() + else: + # fallback if no valid tokens (rare edge case) + raw_is_mean = raw_is_std = raw_is_min = raw_is_max = 0.0 + stats = { + "infer_correction/reject_frac": reject_frac, + "infer_correction/clip_frac": clip_frac, + "infer_correction/clip_and_reject_frac": clip_and_reject_frac, + "infer_correction/clip_or_reject_frac": clip_or_reject_frac, + "infer_correction/seq_reject_frac": seq_reject_frac, + "infer_correction/inferkl_orig": inferkl_orig_agg, + "infer_correction/inferkl_final": inferkl_final_agg, + "infer_correction/raw_is_mean": raw_is_mean, + "infer_correction/raw_is_std": raw_is_std, + "infer_correction/raw_is_min": raw_is_min, + "infer_correction/raw_is_max": raw_is_max, + } + return pg_loss, final_mask, stats @register(dispatch_mode=Dispatch.ONE_TO_ALL) def do_checkpoint(self, global_step): if self.worker_config.offload_nccl: diff --git a/roll/pipeline/rlvr/actor_pg_worker.py b/roll/pipeline/rlvr/actor_pg_worker.py index 477438595..970d13b4d 100644 --- a/roll/pipeline/rlvr/actor_pg_worker.py +++ b/roll/pipeline/rlvr/actor_pg_worker.py @@ -30,6 +30,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ response_mask = data.batch["response_mask"][:, 1:].long() + infer_log_probs = data.batch.get("infer_logprobs", None) final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] @@ -76,7 +77,13 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): pg_loss = self._compute_kimi15_loss(ratio, log_probs, old_log_probs, advantages) else: raise ValueError(f"Unsupported pg_variant: {pg_variant}") - + + if infer_log_probs is not None and self.pipeline_config.infer_correction: + loss, infer_response_mask, infer_stats=self.infer_correction( + old_log_probs=old_log_probs, 
infer_log_probs=infer_log_probs, + response_mask=response_mask,pg_loss=loss) + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() + weighted_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights) original_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, @@ -127,6 +134,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): # 使用缓存的指标 pg_metrics = self._get_pg_metrics(data) + pg_metrics.updata(infer_stats) return total_loss, pg_metrics diff --git a/roll/pipeline/rlvr/actor_worker.py b/roll/pipeline/rlvr/actor_worker.py index 19d0c66de..7e0ad08c1 100644 --- a/roll/pipeline/rlvr/actor_worker.py +++ b/roll/pipeline/rlvr/actor_worker.py @@ -17,6 +17,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): response_mask = data.batch["response_mask"][:, 1:].long() final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] + advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( @@ -53,34 +54,6 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="kl" ) - train_infer_ratio = (old_log_probs - infer_log_probs).exp() - train_infer_diff = old_log_probs.exp() - infer_log_probs.exp() - train_infer_ratio_seq = masked_mean(old_log_probs - infer_log_probs, response_mask, dim=-1).exp().unsqueeze(-1).expand_as(train_infer_ratio) - train_infer_diff_seq = masked_mean(old_log_probs.exp() - infer_log_probs.exp(), response_mask, dim=-1).unsqueeze(-1).expand_as(train_infer_diff) - - train_infer_ratio_mask_mean = 1.0 - train_infer_diff_mask_mean = 1.0 - train_infer_ratio_seq_mask_mean = 1.0 - train_infer_diff_seq_mask_mean = 1.0 - - if self.pipeline_config.train_infer_ratio_mask: - train_infer_ratio_mask = (train_infer_ratio <= self.pipeline_config.train_infer_ratio_threshold_high).float() * (train_infer_ratio >= self.pipeline_config.train_infer_ratio_threshold_low).float() - train_infer_ratio_mask_mean = masked_mean(train_infer_ratio_mask, final_response_mask, dim=-1).mean().detach().item() - final_response_mask = final_response_mask * train_infer_ratio_mask - if self.pipeline_config.train_infer_diff_mask: - train_infer_diff_mask = (train_infer_diff <= self.pipeline_config.train_infer_diff_threshold_high).float() * (train_infer_diff >= self.pipeline_config.train_infer_diff_threshold_low).float() - train_infer_diff_mask_mean = masked_mean(train_infer_diff_mask, final_response_mask, dim=-1).mean().detach().item() - final_response_mask = final_response_mask * train_infer_diff_mask - - if self.pipeline_config.train_infer_ratio_seq_mask: - train_infer_ratio_seq_mask = (train_infer_ratio_seq <= self.pipeline_config.train_infer_ratio_seq_threshold_high).float() * (train_infer_ratio_seq >= self.pipeline_config.train_infer_ratio_seq_threshold_low).float() - train_infer_ratio_seq_mask_mean = masked_mean(train_infer_ratio_seq_mask, final_response_mask, dim=-1).mean().detach().item() - final_response_mask = final_response_mask * train_infer_ratio_seq_mask - if self.pipeline_config.train_infer_diff_seq_mask: - train_infer_diff_seq_mask = (train_infer_diff_seq <= self.pipeline_config.train_infer_diff_seq_threshold_high).float() * (train_infer_diff_seq >= self.pipeline_config.train_infer_diff_seq_threshold_low).float() - train_infer_diff_seq_mask_mean = 
masked_mean(train_infer_diff_seq_mask, final_response_mask, dim=-1).mean().detach().item() - final_response_mask = final_response_mask * train_infer_diff_seq_mask - if self.pipeline_config.importance_sampling == "token": ratio = (log_probs - old_log_probs).exp() elif self.pipeline_config.importance_sampling == "seq": @@ -98,11 +71,13 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) loss = torch.where(advantages < 0, dual_clip_loss, loss) - - if self.pipeline_config.use_rollout_importance_sampling_ratio: - rollout_importance_sampling_clip = (train_infer_ratio > self.pipeline_config.rollout_importance_sampling_ratio_upper_bound).float() - loss = train_infer_ratio.clamp(0, self.pipeline_config.rollout_importance_sampling_ratio_upper_bound) * loss - + + if infer_log_probs is not None and self.pipeline_config.infer_correction: + loss, infer_response_mask, infer_stats=self.infer_correction( + old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, + response_mask=response_mask,pg_loss=loss) + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() + weighted_pg_loss = agg_loss(loss_mat=loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights, loss_scale=loss_scale) @@ -149,16 +124,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): loss_scale=loss_scale) total_loss = total_loss + topr_neg_loss * self.pipeline_config.use_topr_neg_loss_coef metrics['actor/topr_neg_loss'] = topr_neg_loss.detach().item() - - train_infer_prob_metric = { - "actor/train_infer_ratio_mean": masked_mean(train_infer_ratio, response_mask, dim=-1).mean().detach().item(), - "actor/train_infer_diff_mean": masked_mean(train_infer_diff, response_mask, dim=-1).mean().detach().item(), - "actor/train_infer_ratio_mask_mean": train_infer_ratio_mask_mean, - "actor/train_infer_diff_mask_mean": train_infer_diff_mask_mean, - "actor/train_infer_ratio_seq_mask_mean": train_infer_ratio_seq_mask_mean, - "actor/train_infer_diff_seq_mask_mean": train_infer_diff_seq_mask_mean, - } - + loss_metric = { "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(), "actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(), @@ -170,9 +136,6 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): loss_agg_mode=self.pipeline_config.loss_agg_mode, loss_scale=loss_scale).detach().item(), } - if self.pipeline_config.use_rollout_importance_sampling_ratio: - loss_metric["actor/rollout_importance_sampling_clip"] = rollout_importance_sampling_clip.mean().detach().item() - pg_metrics = { "actor/pg_loss": original_pg_loss.detach().item(), "actor/weighted_pg_loss": weighted_pg_loss.detach().item(), @@ -190,9 +153,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): "actor/sample_weights_max": sample_weights.max().detach().item(), **metrics, **loss_metric, - **train_infer_prob_metric } - + pg_metrics.update(infer_stats) return total_loss, pg_metrics def compute_sample_weights(self, data: DataProto, response_mask: torch.Tensor): From e98c4cea7a5d56649711dc5e08db1078bd001fd8 Mon Sep 17 00:00:00 2001 From: millioniron Date: Sun, 7 Dec 2025 14:43:05 +0800 Subject: [PATCH 56/58] =?UTF-8?q?=E9=87=8D=E6=96=B0=E4=BF=AE=E8=AE=A2?= =?UTF-8?q?=E4=BA=86=E6=95=B4=E4=B8=AA=E7=9A=84=E6=8E=92=E7=89=88=EF=BC=8C?= 
=?UTF-8?q?=E6=8A=BD=E8=B1=A1=E5=87=BA=E4=BA=86=E4=B8=80=E4=B8=AA=E7=B1=BB?= =?UTF-8?q?=EF=BC=8C=E4=BD=BF=E5=BE=97=E5=8F=AF=E4=BB=A5=E6=9B=B4=E5=8A=A0?= =?UTF-8?q?=E8=87=AA=E7=94=B1=E5=AE=9A=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../rlvr_infer_correction_config.yaml | 12 +- roll/configs/base_config.py | 6 +- roll/configs/generating_args.py | 2 +- roll/pipeline/base_worker.py | 169 +------- roll/pipeline/rlvr/actor_pg_worker.py | 23 +- roll/pipeline/rlvr/actor_worker.py | 22 +- roll/utils/infer_correction.py | 380 ++++++++++++++++++ 7 files changed, 439 insertions(+), 175 deletions(-) create mode 100644 roll/utils/infer_correction.py diff --git a/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml index e2cbac8a7..1bbe3c8e8 100644 --- a/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml +++ b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml @@ -80,8 +80,8 @@ infer_is_threshold_min: 0.0 infer_is_threshold_max: 2.0 # 1.5~5.0 enable_token_reject: false -infer_token_rs_threshold_min: 0.0 -infer_token_rs_threshold_max: 2.0 # 2~10 +infer_token_mask_threshold_min: 0.0 +infer_token_mask_threshold_max: 2.0 # 2~10 enable_catastrophic_reject: false infer_catastrophic_threshold: 1e-4 @@ -89,12 +89,12 @@ infer_catastrophic_threshold: 1e-4 enable_seq_reject: None # enable_seq_reject: sequence -# infer_seq_rs_threshold_min: 0.1 -# infer_seq_rs_threshold_max: 10 +# infer_seq_mask_threshold_min: 0.1 +# infer_seq_mask_threshold_max: 10 # enable_seq_reject: geometric -# infer_seq_rs_threshold_min: 0.999 -# infer_seq_rs_threshold_max: 1.001 +# infer_seq_mask_threshold_min: 0.999 +# infer_seq_mask_threshold_max: 1.001 validation: data_args: diff --git a/roll/configs/base_config.py b/roll/configs/base_config.py index 8dc75d3f3..a6ea37f09 100644 --- a/roll/configs/base_config.py +++ b/roll/configs/base_config.py @@ -407,9 +407,9 @@ class PPOConfig(BaseConfig): default=False, metadata={"help": "Whether to apply importance sampling correction during inference."} ) - infer_is_mode: Literal["token", "sequence", "none"] = field( - default="token", - metadata={"help": "IS weighting mode: 'token' (per-token ratio), 'sequence' (per-sequence ratio), 'none' (no IS weighting)."} + infer_is_mode: Literal["token", "sequence", "None"] = field( + default="None", + metadata={"help": "IS weighting mode: 'token' (per-token ratio), 'sequence' (per-sequence ratio), 'None' (no IS weighting)."} ) # Clipping thresholds (used in IS weighting) infer_is_threshold_min: float = field( diff --git a/roll/configs/generating_args.py b/roll/configs/generating_args.py index 0822614ba..5d1f94f95 100644 --- a/roll/configs/generating_args.py +++ b/roll/configs/generating_args.py @@ -59,7 +59,7 @@ class GeneratingArguments: metadata={"help": "Whether to include the stop strings in output text."}, ) logprobs: Optional[int] = field( - default=None, + default=0, metadata={"help": "Whether return infer log-prob."}, ) diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py index 7148752a1..e7cb65389 100644 --- a/roll/pipeline/base_worker.py +++ b/roll/pipeline/base_worker.py @@ -25,10 +25,11 @@ postprocess_generate, GenerateRequestType, agg_loss, - masked_sum + masked_sum, ) from roll.utils.offload_nccl import reload_process_groups from roll.utils.offload_states import OffloadStateType +from roll.utils.infer_correction import InferCorrectionHandler from 
roll.utils.dynamic_batching import make_mini_batch_iter_for_dynamic_batching from roll.platforms import current_platform @@ -309,10 +310,10 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): response_mask = data.batch["response_mask"][:, 1:].long() final_response_mask = data.batch.get("final_response_mask", response_mask) + ref_log_probs = data.batch["ref_log_probs"] advantages = data.batch["advantages"] infer_log_probs = data.batch.get("infer_logprobs", None) - log_probs = self.strategy.op_compute_log_probs( logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"] ) @@ -328,19 +329,26 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-pg_loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) pg_loss = torch.where(advantages < 0, dual_clip_loss, pg_loss) - + + infer_stats = {} if infer_log_probs is not None and self.pipeline_config.infer_correction: - pg_loss, infer_response_mask, infer_stats=self.infer_correction( - old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, - response_mask=response_mask,pg_loss=pg_loss) + correction_handler = InferCorrectionHandler(self.pipeline_config) + + pg_loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=pg_loss + ) + # 更新最终掩码 final_response_mask = (final_response_mask.bool() & infer_response_mask).long() pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) - kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=final_response_mask, + kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=final_response_mask, + kl_penalty="k3") kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) - approxkl = compute_approx_kl( log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="mse" ) @@ -384,150 +392,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): pg_metrics.update(infer_stats) return total_loss, pg_metrics - - - def infer_correction( - self, - old_log_probs: torch.Tensor, - infer_log_probs: torch.Tensor, - response_mask: torch.Tensor, - pg_loss: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, dict]: - """ - 处理 importance sampling ratio,支持 IS 裁剪与多种 reject 策略。 - 返回更新后的 pg_loss、mask 和详细统计信息。 - """ - # Step 0: Shape alignment - if infer_log_probs.shape[1] == old_log_probs.shape[1]+1: - infer_log_probs = infer_log_probs[:, 1:] # align with response_mask[:, 1:] - assert old_log_probs.shape == infer_log_probs.shape == response_mask.shape, \ - f"Shape mismatch: {old_log_probs.shape}, {infer_log_probs.shape}, {response_mask.shape}" - # Step 1: Compute log-ratio and ratio - log_ratio = old_log_probs - infer_log_probs # [B, T] - ratio = torch.exp(log_ratio) # [B, T] - # Step 2: Apply IS weighting strategy (optional) - if self.pipeline_config.infer_is_mode == "token": - raw_is_weight = ratio - elif self.pipeline_config.infer_is_mode == "sequence": - log_ratio_sum = masked_sum(log_ratio, response_mask, dim=-1).unsqueeze(-1) # [B, 1] - raw_is_weight = torch.exp(log_ratio_sum).expand_as(ratio) # [B, T] - elif self.pipeline_config.infer_is_mode in (None, "none", ""): - raw_is_weight = torch.ones_like(ratio) - else: - raw_is_weight = 
torch.ones_like(ratio) - # Clamp to get final is_weight (used for loss) - is_weight = raw_is_weight.clamp( - min=self.pipeline_config.infer_is_threshold_min, - max=self.pipeline_config.infer_is_threshold_max - ).detach() - # Step 3: Build rejection mask - original_valid = response_mask > 0.5 # [B, T], bool - keep_mask = original_valid.clone() - # (a) Token-level ratio reject - if getattr(self.pipeline_config, 'enable_token_reject', False): - ratio_too_high = ratio > self.pipeline_config.infer_token_mask_threshold_max - ratio_too_low = ratio < self.pipeline_config.infer_token_mask_threshold_min - token_reject = ratio_too_high | ratio_too_low - keep_mask = keep_mask & (~token_reject) - # (b) Catastrophic reject - if getattr(self.pipeline_config, 'enable_catastrophic_reject', False): - catastrophic = (ratio < self.pipeline_config.infer_catastrophic_threshold) & original_valid - has_catastrophic = catastrophic.any(dim=-1, keepdim=True) - keep_mask = keep_mask & (~has_catastrophic) - # (c) Sequence-level reject - if getattr(self.pipeline_config, 'enable_seq_reject', False): - if self.pipeline_config.enable_seq_reject=="sequence": - log_ratio_sum = masked_sum(log_ratio, response_mask, dim=-1) # [B] - seq_ratio = torch.exp(log_ratio_sum) # [B] - seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max - seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min - seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) - keep_mask = keep_mask & (~seq_reject) - elif self.pipeline_config.enable_seq_reject=="geometric": - log_ratio_mean = masked_mean(log_ratio, response_mask, dim=-1) # [B] - seq_ratio = torch.exp(log_ratio_mean) # [B] - seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max - seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min - seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) - keep_mask = keep_mask & (~seq_reject) - # final_mask = keep_mask.float() - final_mask = keep_mask - # Step 4: Reweight policy loss - pg_loss = pg_loss * is_weight - # Step 5: Compute detailed stats over original_valid tokens - # Rejected mask - rejected_mask = original_valid & (~keep_mask) # [B, T] - # Clipped mask: only meaningful if IS weighting is active - if self.pipeline_config.infer_is_mode in ("token", "sequence"): - clipped_low = (raw_is_weight <= self.pipeline_config.infer_is_threshold_min) & original_valid - clipped_high = (raw_is_weight >= self.pipeline_config.infer_is_threshold_max) & original_valid - clipped_mask = clipped_low | clipped_high # [B, T] - else: - clipped_mask = torch.zeros_like(original_valid) # no clipping - # Compute fractions - def _compute_frac(mask_tensor): - return agg_loss( - loss_mat=mask_tensor.float(), - loss_mask=response_mask, - loss_agg_mode="token-mean" # force token-wise average - ).detach().item() - clip_frac = _compute_frac(clipped_mask) - reject_frac = _compute_frac(rejected_mask) - clip_and_reject_frac = _compute_frac(clipped_mask & rejected_mask) - clip_or_reject_frac = _compute_frac(clipped_mask | rejected_mask) - # A sequence is rejected if NO token is kept (i.e., all final_mask == 0 for that seq) - seq_has_valid = original_valid.any(dim=-1) # [B], bool: seq has >=1 valid token - seq_completely_rejected = (~keep_mask).all(dim=-1) & seq_has_valid # [B] - total_valid_seqs = seq_has_valid.sum().item() - rejected_seqs = seq_completely_rejected.sum().item() - seq_reject_frac = rejected_seqs / total_valid_seqs if total_valid_seqs > 0 else 0.0 - - ### kl metric - inferkl_orig = 
compute_approx_kl( - log_probs=infer_log_probs, - log_probs_base=old_log_probs, - action_mask=response_mask, # ← original mask - kl_penalty="kl" - ) - inferkl_final = compute_approx_kl( - log_probs=infer_log_probs, - log_probs_base=old_log_probs, - action_mask=final_mask, # ← after rejection - kl_penalty="kl" - ) - inferkl_orig_agg = agg_loss( - loss_mat=inferkl_orig, - loss_mask=response_mask, - loss_agg_mode=self.pipeline_config.loss_agg_mode - ).detach().item() - inferkl_final_agg = agg_loss( - loss_mat=inferkl_final, - loss_mask=final_mask, - loss_agg_mode=self.pipeline_config.loss_agg_mode - ).detach().item() - valid_raw_is_weight = raw_is_weight[original_valid] # [N_valid_tokens,] - if valid_raw_is_weight.numel() > 0: - raw_is_mean = valid_raw_is_weight.mean().detach().item() - raw_is_std = valid_raw_is_weight.std(unbiased=False).detach().item() - raw_is_min = valid_raw_is_weight.min().detach().item() - raw_is_max = valid_raw_is_weight.max().detach().item() - else: - # fallback if no valid tokens (rare edge case) - raw_is_mean = raw_is_std = raw_is_min = raw_is_max = 0.0 - stats = { - "infer_correction/reject_frac": reject_frac, - "infer_correction/clip_frac": clip_frac, - "infer_correction/clip_and_reject_frac": clip_and_reject_frac, - "infer_correction/clip_or_reject_frac": clip_or_reject_frac, - "infer_correction/seq_reject_frac": seq_reject_frac, - "infer_correction/inferkl_orig": inferkl_orig_agg, - "infer_correction/inferkl_final": inferkl_final_agg, - "infer_correction/raw_is_mean": raw_is_mean, - "infer_correction/raw_is_std": raw_is_std, - "infer_correction/raw_is_min": raw_is_min, - "infer_correction/raw_is_max": raw_is_max, - } - return pg_loss, final_mask, stats + @register(dispatch_mode=Dispatch.ONE_TO_ALL) def do_checkpoint(self, global_step): if self.worker_config.offload_nccl: diff --git a/roll/pipeline/rlvr/actor_pg_worker.py b/roll/pipeline/rlvr/actor_pg_worker.py index 970d13b4d..e83ed0a9b 100644 --- a/roll/pipeline/rlvr/actor_pg_worker.py +++ b/roll/pipeline/rlvr/actor_pg_worker.py @@ -4,6 +4,7 @@ from roll.distributed.scheduler.protocol import DataProto from roll.utils.functionals import masked_mean, agg_loss, compute_approx_kl from roll.pipeline.rlvr.actor_worker import ActorWorker +from roll.utils.infer_correction import InferCorrectionHandler class ActorPGWorker(ActorWorker): @@ -30,10 +31,11 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ response_mask = data.batch["response_mask"][:, 1:].long() - infer_log_probs = data.batch.get("infer_logprobs", None) final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] + old_log_probs = data.batch["old_log_probs"] + infer_log_probs = data.batch.get("infer_logprobs", None) advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( @@ -77,13 +79,19 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): pg_loss = self._compute_kimi15_loss(ratio, log_probs, old_log_probs, advantages) else: raise ValueError(f"Unsupported pg_variant: {pg_variant}") - + + infer_stats = {} if infer_log_probs is not None and self.pipeline_config.infer_correction: - loss, infer_response_mask, infer_stats=self.infer_correction( - old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, - response_mask=response_mask,pg_loss=loss) + correction_handler = InferCorrectionHandler(self.pipeline_config) + pg_loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + 
infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=pg_loss + ) + # 更新最终掩码 final_response_mask = (final_response_mask.bool() & infer_response_mask).long() - + weighted_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights) original_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, @@ -134,7 +142,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): # 使用缓存的指标 pg_metrics = self._get_pg_metrics(data) - pg_metrics.updata(infer_stats) + + pg_metrics.update(infer_stats) return total_loss, pg_metrics diff --git a/roll/pipeline/rlvr/actor_worker.py b/roll/pipeline/rlvr/actor_worker.py index 7e0ad08c1..da34d3a38 100644 --- a/roll/pipeline/rlvr/actor_worker.py +++ b/roll/pipeline/rlvr/actor_worker.py @@ -4,6 +4,7 @@ from roll.distributed.scheduler.protocol import DataProto from roll.pipeline.base_worker import ActorWorker as BaseActorWorker from roll.utils.functionals import masked_mean, agg_loss, compute_approx_kl +from roll.utils.infer_correction import InferCorrectionHandler class ActorWorker(BaseActorWorker): @@ -17,7 +18,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): response_mask = data.batch["response_mask"][:, 1:].long() final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] - + advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( @@ -71,13 +72,21 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) loss = torch.where(advantages < 0, dual_clip_loss, loss) - + + infer_stats = {} if infer_log_probs is not None and self.pipeline_config.infer_correction: - loss, infer_response_mask, infer_stats=self.infer_correction( - old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, - response_mask=response_mask,pg_loss=loss) + correction_handler = InferCorrectionHandler(self.pipeline_config) + loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=loss + ) + # 更新最终掩码 final_response_mask = (final_response_mask.bool() & infer_response_mask).long() - + + + weighted_pg_loss = agg_loss(loss_mat=loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights, loss_scale=loss_scale) @@ -124,6 +133,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): loss_scale=loss_scale) total_loss = total_loss + topr_neg_loss * self.pipeline_config.use_topr_neg_loss_coef metrics['actor/topr_neg_loss'] = topr_neg_loss.detach().item() + loss_metric = { "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(), diff --git a/roll/utils/infer_correction.py b/roll/utils/infer_correction.py new file mode 100644 index 000000000..215203081 --- /dev/null +++ b/roll/utils/infer_correction.py @@ -0,0 +1,380 @@ +from typing import Literal, Optional, Tuple, Dict, Any +import torch + +class StatsCollector: + """统一收集诊断指标的类""" + def __init__(self, prefix: str = "infer_correction"): + self.prefix = prefix + self.stats: Dict[str, Any] = {} + self.tensor_stats: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {} + + def add(self, name: str, value: Any): + """添加标量指标""" + self.stats[f"{self.prefix}/{name}"] = value.item() if torch.is_tensor(value) else 
value + + def add_tensor_stat(self, name: str, tensor: torch.Tensor, mask: torch.Tensor): + """添加张量统计指标(延迟计算)""" + self.tensor_stats[name] = (tensor, mask) + + def compute_tensor_stats(self): + """严格遵循原始代码的数据移动策略""" + for name, (tensor, mask) in self.tensor_stats.items(): + # 1. 确保在同一设备上 + if tensor.device != mask.device: + mask = mask.to(tensor.device) + + # 2. 直接在原始代码风格中计算:先筛选,再移动到CPU + mask=mask.bool() + valid = tensor[mask] + + # 3. 严格按照原始代码逻辑处理 + if valid.numel() > 0: + # 关键:先detach()再item(),确保在CPU上计算 + valid_cpu = valid.detach().cpu() + self.add(f"{name}_mean", valid_cpu.mean().item()) + self.add(f"{name}_std", valid_cpu.std(unbiased=False).item() if valid_cpu.numel() > 1 else 0.0) + self.add(f"{name}_min", valid_cpu.min().item()) + self.add(f"{name}_max", valid_cpu.max().item()) + else: + self.add(f"{name}_mean", 0.0) + self.add(f"{name}_std", 0.0) + self.add(f"{name}_min", 0.0) + self.add(f"{name}_max", 0.0) + + self.tensor_stats.clear() + + def get_metrics(self) -> Dict[str, float]: + """获取所有指标""" + return self.stats.copy() + +class InferCorrectionHandler: + """处理重要性采样校正和样本拒绝的核心类""" + def __init__(self, pipeline_config: "PPOConfig"): + self.pipeline_config = pipeline_config + self.stats = StatsCollector() + + def __call__( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor, + pg_loss: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]: + """ + 主入口:执行重要性采样校正和样本拒绝 + + Args: + old_log_probs: 历史策略的log概率 [B, T] + infer_log_probs: 生成时策略的log概率 [B, T] + response_mask: 有效token掩码 [B, T] + pg_loss: 原始策略梯度损失 [B, T] + + Returns: + weighted_loss: 重加权后的损失 + final_mask: 最终保留的token掩码 + metrics: 诊断指标字典 + """ + # 1. 对齐形状 + infer_log_probs = self._align_shapes(old_log_probs, infer_log_probs, response_mask) + + # 2. 计算IS权重 + ratio, raw_is_weight, is_weight = self._compute_is_weights(old_log_probs, infer_log_probs, response_mask) + + # 3. 收集基础统计 + self._collect_base_stats(ratio, response_mask) + + # 4. 应用拒绝策略 + keep_mask = response_mask.clone() + keep_mask = self._apply_token_rejection(ratio, keep_mask) + keep_mask = self._apply_catastrophic_rejection(ratio, keep_mask, response_mask) + keep_mask = self._apply_sequence_rejection(ratio, keep_mask, response_mask) + + # 5. 计算拒绝统计 + self._collect_rejection_stats(ratio, raw_is_weight, keep_mask, response_mask) + + # 6. 重加权损失 + weighted_loss = pg_loss * is_weight + + # 7. 计算KL指标 + self._compute_kl_metrics(old_log_probs, infer_log_probs, keep_mask, response_mask) + + # 8. 
批量计算张量统计 + self.stats.compute_tensor_stats() + + return weighted_loss, keep_mask, self.stats.get_metrics() + + def _align_shapes( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor + ) -> torch.Tensor: + """对齐log概率张量形状""" + if infer_log_probs.shape[1] == old_log_probs.shape[1] + 1: + infer_log_probs = infer_log_probs[:, 1:] + + assert old_log_probs.shape == infer_log_probs.shape == response_mask.shape, ( + f"Shape mismatch: old_log_probs {old_log_probs.shape}, " + f"infer_log_probs {infer_log_probs.shape}, " + f"response_mask {response_mask.shape}" + ) + return infer_log_probs + + def _compute_is_weights( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + 计算重要性采样权重 + + Returns: + ratio: 原始重要性比率 [B, T] + raw_is_weight: 未裁剪的IS权重 [B, T] + is_weight: 裁剪后的IS权重 [B, T] + """ + log_ratio = old_log_probs - infer_log_probs + ratio = torch.exp(log_ratio) + + if self.pipeline_config.infer_is_mode == "token": + raw_is_weight = ratio + elif self.pipeline_config.infer_is_mode == "sequence": + # 序列级IS:使用序列总log-ratio + log_ratio_sum = self._masked_sum(log_ratio, response_mask, dim=-1).unsqueeze(-1) + seq_ratio = torch.exp(log_ratio_sum) + raw_is_weight = seq_ratio.expand_as(ratio) + # 收集序列级统计 + self.stats.add_tensor_stat("seq_ratio", seq_ratio.squeeze(-1), torch.ones_like(seq_ratio.squeeze(-1), dtype=torch.bool)) + else: # "None" or any other value + raw_is_weight = torch.ones_like(ratio) + + # 裁剪IS权重 + is_weight = raw_is_weight.clamp( + min=self.pipeline_config.infer_is_threshold_min, + max=self.pipeline_config.infer_is_threshold_max + ).detach() + + return ratio, raw_is_weight, is_weight + + def _collect_base_stats(self, ratio: torch.Tensor, response_mask: torch.Tensor): + """收集基础统计指标""" + self.stats.add_tensor_stat("token_ratio", ratio, response_mask) + + if self.pipeline_config.infer_is_mode in ("token", "sequence"): + # 1. 裁剪比例统计(现有代码) + clipped_low = ratio <= self.pipeline_config.infer_is_threshold_min + clipped_high = ratio >= self.pipeline_config.infer_is_threshold_max + clipped = clipped_low | clipped_high + self.stats.add("token_clip_low_frac", self._agg_loss(clipped_low.float(), response_mask)) + self.stats.add("token_clip_high_frac", self._agg_loss(clipped_high.float(), response_mask)) + self.stats.add("token_clip_frac", self._agg_loss(clipped.float(), response_mask)) + + # 2. 
添加缺失的:裁剪后权重的分布统计 + if self.pipeline_config.infer_is_mode == "token": + # 重新计算裁剪后的权重 + is_weight = ratio.clamp( + min=self.pipeline_config.infer_is_threshold_min, + max=self.pipeline_config.infer_is_threshold_max + ) + # 添加缺失的统计 + self.stats.add_tensor_stat("token_is_weight", is_weight, response_mask) + + elif self.pipeline_config.infer_is_mode == "sequence": + # 序列级IS权重已在_compute_is_weights中添加 + pass + + def _apply_token_rejection( + self, + ratio: torch.Tensor, + keep_mask: torch.Tensor + ) -> torch.Tensor: + """应用token级拒绝策略""" + if not self.pipeline_config.enable_token_reject: + return keep_mask + + ratio_too_high = ratio > self.pipeline_config.infer_token_mask_threshold_max + ratio_too_low = ratio < self.pipeline_config.infer_token_mask_threshold_min + token_reject = ratio_too_high | ratio_too_low + + # 更新掩码:丢弃被拒绝的token + new_keep_mask = keep_mask & (~token_reject) + + # 收集统计 + self.stats.add("token_reject_low_frac", self._agg_loss(ratio_too_low.float(), keep_mask)) + self.stats.add("token_reject_high_frac", self._agg_loss(ratio_too_high.float(), keep_mask)) + + return new_keep_mask + + def _apply_catastrophic_rejection( + self, + ratio: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ) -> torch.Tensor: + """应用灾难性拒绝策略""" + if not self.pipeline_config.enable_catastrophic_reject: + return keep_mask + + # 识别灾难性token + catastrophic = (ratio < self.pipeline_config.infer_catastrophic_threshold) & response_mask + + # 检查哪些序列包含灾难性token + seq_has_catastrophic = catastrophic.any(dim=-1, keepdim=True) + + # 更新掩码:丢弃包含灾难性token的整个序列 + new_keep_mask = keep_mask & (~seq_has_catastrophic) + + # 收集统计 + catastrophic_token_frac = self._agg_loss(catastrophic.float(), response_mask) + self.stats.add("catastrophic_token_frac", catastrophic_token_frac) + + # 计算包含灾难性token的序列比例 + seq_has_valid = response_mask.any(dim=-1) + seq_has_catastrophic_flat = catastrophic.any(dim=-1) & seq_has_valid + catastrophic_seq_frac = ( + seq_has_catastrophic_flat.sum().float() / seq_has_valid.sum().float() + if seq_has_valid.sum() > 0 else 0.0 + ) + self.stats.add("catastrophic_seq_frac", catastrophic_seq_frac) + + return new_keep_mask + + def _apply_sequence_rejection( + self, + ratio: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ) -> torch.Tensor: + """应用序列级拒绝策略""" + if self.pipeline_config.enable_seq_reject in (None, "None", "none"): + return keep_mask + + # 计算序列级比率 + if self.pipeline_config.enable_seq_reject == "sequence": + log_ratio_agg = self._masked_sum(torch.log(ratio), response_mask, dim=-1) + elif self.pipeline_config.enable_seq_reject == "geometric": + log_ratio_agg = self._masked_mean(torch.log(ratio), response_mask, dim=-1) + else: + return keep_mask + + seq_ratio = torch.exp(log_ratio_agg) + + # 识别要拒绝的序列 + seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max + seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min + seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) + + # 更新掩码 + new_keep_mask = keep_mask & (~seq_reject) + + # 收集统计 + seq_has_valid = response_mask.any(dim=-1) + total_valid_seqs = seq_has_valid.sum().item() + + seq_reject_low = seq_too_low & seq_has_valid + seq_reject_high = seq_too_high & seq_has_valid + + seq_reject_low_frac = seq_reject_low.sum().item() / total_valid_seqs if total_valid_seqs > 0 else 0.0 + seq_reject_high_frac = seq_reject_high.sum().item() / total_valid_seqs if total_valid_seqs > 0 else 0.0 + + self.stats.add("seq_reject_low_frac", seq_reject_low_frac) + 
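        # seq_reject_low_frac / seq_reject_high_frac are computed over sequences that
        # have at least one valid response token and reflect only this sequence-level
        # check; token-level and catastrophic rejections are reported separately.
        # Rough example for the "geometric" mode with the thresholds used in the example
        # configs (min 0.999, max 1.001): a response whose mean per-token log-ratio is
        # -0.002 gives seq_ratio = exp(-0.002) ≈ 0.998, so the whole sequence is masked.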
self.stats.add("seq_reject_high_frac", seq_reject_high_frac) + + return new_keep_mask + + def _collect_rejection_stats( + self, + ratio: torch.Tensor, + raw_is_weight: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ): + """收集拒绝相关的统计指标""" + + # 计算被拒绝的token + rejected_mask = response_mask & (~keep_mask) + self.stats.add("reject_frac", self._agg_loss(rejected_mask.float(), response_mask)) + + + # 仅在序列拒绝启用时计算序列级拒绝率 + if self.pipeline_config.enable_seq_reject not in (None, "None", "none"): + seq_has_valid = response_mask.any(dim=-1) + seq_completely_rejected = (~keep_mask).all(dim=-1) & seq_has_valid + total_valid_seqs = seq_has_valid.sum().item() + rejected_seqs = seq_completely_rejected.sum().item() + seq_reject_frac = rejected_seqs / total_valid_seqs if total_valid_seqs > 0 else 0.0 + self.stats.add("seq_reject_frac", seq_reject_frac) + else: + # 未启用时显式设为0.0 + self.stats.add("seq_reject_frac", 0.0) + + + if self.pipeline_config.infer_is_mode in ("token", "sequence"): + # 使用已计算的rejected_mask + clipped_mask = ((raw_is_weight <= self.pipeline_config.infer_is_threshold_min) | + (raw_is_weight >= self.pipeline_config.infer_is_threshold_max)) & response_mask + + clip_and_reject_frac = self._agg_loss((clipped_mask & rejected_mask).float(), response_mask) + clip_or_reject_frac = self._agg_loss((clipped_mask | rejected_mask).float(), response_mask) + + self.stats.add("token_clip_and_reject_frac", clip_and_reject_frac) + self.stats.add("token_clip_or_reject_frac", clip_or_reject_frac) + else: + # 关键:为未启用IS的情况提供默认值 + self.stats.add("token_clip_and_reject_frac", 0.0) + self.stats.add("token_clip_or_reject_frac", 0.0) + + def _compute_kl_metrics( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ): + """计算KL散度指标""" + # 原始KL(所有有效token) + inferkl_orig = self._compute_approx_kl(infer_log_probs, old_log_probs, response_mask, kl_penalty="kl") + inferkl_orig_agg = self._agg_loss(inferkl_orig, response_mask) + self.stats.add("inferkl", inferkl_orig_agg) + + # 拒绝后KL(仅保留的token) + inferkl_final = self._compute_approx_kl(infer_log_probs, old_log_probs, keep_mask, kl_penalty="kl") + inferkl_final_agg = self._agg_loss(inferkl_final, keep_mask) + self.stats.add("inferkl_reject", inferkl_final_agg) + + # --- 辅助方法(使用已有工具函数)--- + def _compute_approx_kl( + self, + log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + action_mask: torch.Tensor, + kl_penalty: str = "kl" + ) -> torch.Tensor: + """使用已有的compute_approx_kl函数计算近似KL散度""" + from roll.utils.functionals import compute_approx_kl + return compute_approx_kl( + log_probs=log_probs, + log_probs_base=log_probs_base, + action_mask=action_mask, + kl_penalty=kl_penalty + ) + + def _agg_loss(self, loss_mat: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor: + """使用已有的agg_loss函数聚合损失""" + from roll.utils.functionals import agg_loss + return agg_loss( + loss_mat=loss_mat, + loss_mask=loss_mask, + loss_agg_mode=self.pipeline_config.loss_agg_mode + ) + + def _masked_sum(self, tensor: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: + """使用已有的masked_sum函数在掩码区域求和""" + from roll.utils.functionals import masked_sum + return masked_sum(tensor, mask, dim=dim) + + def _masked_mean(self, tensor: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: + """使用已有的masked_mean函数在掩码区域计算均值""" + from roll.utils.functionals import masked_mean + return masked_mean(tensor, mask, dim=dim) \ No newline at end of file From 
c3d1121e10cd07cb7bd068ac3e4afc3c867cd533 Mon Sep 17 00:00:00 2001 From: millioniron Date: Mon, 8 Dec 2025 20:04:15 +0800 Subject: [PATCH 57/58] modified: roll/pipeline/agentic/env_manager/step_env_manager.py --- roll/pipeline/agentic/env_manager/step_env_manager.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/roll/pipeline/agentic/env_manager/step_env_manager.py b/roll/pipeline/agentic/env_manager/step_env_manager.py index 7db4d0df0..a1d7cbf28 100644 --- a/roll/pipeline/agentic/env_manager/step_env_manager.py +++ b/roll/pipeline/agentic/env_manager/step_env_manager.py @@ -79,7 +79,6 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): input_ids =torch.tensor(token_ids, dtype=torch.long).unsqueeze(0) attention_mask = torch.tensor([1] * len(token_ids), dtype=torch.long).unsqueeze(0) response_mask = torch.tensor(response_masks, dtype=torch.bool).unsqueeze(0) - infer_logprobs = [] if "infer_logprobs" in history: infer_logprobs = [0] * len(history["prompt_ids"]) + history["infer_logprobs"] @@ -88,7 +87,6 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): prompt_mask = torch.tensor(prompt_masks, dtype=torch.bool).unsqueeze(0) score_tensor = torch.tensor([0] * len(token_ids), dtype=torch.float).unsqueeze(0) score_tensor[0][-1] = history['reward'] - infer_logprobs = history["infer_logprobs"].flatten().unsqueeze(0) position_ids = attention_mask.cumsum(dim=-1) input_ids = pad_to_length(input_ids, length=self.pipeline_config.sequence_length, pad_value=self.tokenizer.pad_token_id) @@ -106,7 +104,6 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): "response_mask": response_mask, "prompt_mask": prompt_mask, "scores": score_tensor, - "infer_logprobs": infer_logprobs, }, batch_size=input_ids.shape[0]), non_tensor_batch={ From 515bf39e15de7c249cc0edaf87ead8a7c24bc087 Mon Sep 17 00:00:00 2001 From: millioniron Date: Mon, 8 Dec 2025 20:28:40 +0800 Subject: [PATCH 58/58] =?UTF-8?q?=E5=8E=BB=E6=8E=89=E5=8E=9F=E6=9C=AC?= =?UTF-8?q?=E5=AE=98=E6=96=B9=E7=9A=84train-infer=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- roll/pipeline/agentic/agentic_actor_worker.py | 31 ++++++++++++------- roll/pipeline/rlvr/rlvr_config.py | 16 ---------- 2 files changed, 19 insertions(+), 28 deletions(-) diff --git a/roll/pipeline/agentic/agentic_actor_worker.py b/roll/pipeline/agentic/agentic_actor_worker.py index 75510c675..e97fef5ba 100644 --- a/roll/pipeline/agentic/agentic_actor_worker.py +++ b/roll/pipeline/agentic/agentic_actor_worker.py @@ -4,6 +4,7 @@ from roll.distributed.scheduler.protocol import DataProto from roll.pipeline.base_worker import ActorWorker as BaseActorWorker from roll.utils.functionals import masked_mean, agg_loss, compute_approx_kl +from roll.utils.infer_correction import InferCorrectionHandler class ActorWorker(BaseActorWorker): @@ -14,6 +15,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): output_tensor: torch.Tensor, model.forward()的输出Tensor """ response_mask = data.batch["response_mask"][:, 1:].long() + final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] advantages = data.batch["advantages"] @@ -29,8 +31,6 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): else: ratio = (log_probs - old_log_probs).exp() - train_infer_ratio = (log_probs - infer_log_probs).exp() - train_infer_diff = log_probs.exp() - infer_log_probs.exp() pg_clip_low = self.pipeline_config.pg_clip_low if 
self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip pg_clip_high = self.pipeline_config.pg_clip_high if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip @@ -41,11 +41,23 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): dual_clip_loss = -torch.max(-pg_loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) pg_loss = torch.where(advantages < 0, dual_clip_loss, pg_loss) - pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + infer_stats = {} + if infer_log_probs is not None and self.pipeline_config.infer_correction: + correction_handler = InferCorrectionHandler(self.pipeline_config) + pg_loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=pg_loss + ) + # 更新最终掩码 + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() + + pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) - kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=response_mask, + kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=final_response_mask, kl_penalty="k3") - kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) approxkl = compute_approx_kl( log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="mse" @@ -70,11 +82,6 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): ) total_loss = total_loss - entropy_loss * self.pipeline_config.entropy_loss_coef - train_infer_prob_metric = { - "actor/train_infer_ratio_mean": masked_mean(train_infer_ratio, response_mask, dim=-1).mean().detach().item(), - "actor/train_infer_diff_mean": masked_mean(train_infer_diff, response_mask, dim=-1).mean().detach().item(), - } - pg_metrics = { "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(), "actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(), @@ -91,8 +98,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(), "actor/policykl": agg_loss(loss_mat=policykl, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(), - **train_infer_prob_metric } - + pg_metrics.update(infer_stats) + return total_loss, pg_metrics diff --git a/roll/pipeline/rlvr/rlvr_config.py b/roll/pipeline/rlvr/rlvr_config.py index ba51d62ef..d053e2e9a 100644 --- a/roll/pipeline/rlvr/rlvr_config.py +++ b/roll/pipeline/rlvr/rlvr_config.py @@ -149,22 +149,6 @@ class RLVRConfig(PPOConfig): importance_sampling: Literal["token", "seq"] = ( field(default="token", metadata={"help": "policy importance sampling"}) ) - use_rollout_importance_sampling_ratio: bool = field(default=False, metadata={"help": "apply train/infer ratio as token-level loss weight"}) - rollout_importance_sampling_ratio_upper_bound: float = field(default=1.2) - - train_infer_ratio_mask: bool = field(default=False, metadata={"help": "apply train/infer ratio as token-level response mask"}) - train_infer_ratio_threshold_low: float = field(default=0.8) - train_infer_ratio_threshold_high: float = field(default=1.2) - train_infer_diff_mask: bool 
= field(default=False, metadata={"help": "apply train-infer diff as token-level response mask"}) - train_infer_diff_threshold_low: float = field(default=-0.2) - train_infer_diff_threshold_high: float = field(default=0.2) - - train_infer_ratio_seq_mask: bool = field(default=False, metadata={"help": "apply train/infer ratio as sequence-level response mask"}) - train_infer_ratio_seq_threshold_low: float = field(default=0.8) - train_infer_ratio_seq_threshold_high: float = field(default=1.2) - train_infer_diff_seq_mask: bool = field(default=False, metadata={"help": "apply train-infer diff as sequence-level response mask"}) - train_infer_diff_seq_threshold_low: float = field(default=-0.2) - train_infer_diff_seq_threshold_high: float = field(default=0.2) val_greedy: bool = field(default=False, metadata={"help": "Use greedy for validation"}) val_n_sample: int = field(default=1, metadata={"help": "Number of samples for validation"})
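
Taken together, the last two patches replace the removed `train_infer_*` switches in `RLVRConfig` with the `InferCorrectionHandler` call inside `loss_func`. The snippet below is an illustrative sketch only, not part of the patch set: it re-implements the token- and sequence-level ratio logic from the handler helpers with plain PyTorch on toy numbers (the thresholds and tensors are made-up assumptions), so the rejection behaviour and its effect on the aggregated policy loss can be sanity-checked outside the pipeline.

```python
# Illustrative sketch only -- not part of the patch set. It mirrors the ratio and
# rejection logic from the InferCorrectionHandler helpers above using plain PyTorch;
# threshold values here are arbitrary assumptions for the example.
import torch

# One sequence with 4 response tokens: trainer vs. generator log-probs.
old_log_probs   = torch.tensor([[-1.0, -2.0, -0.5, -3.0]])   # trainer side (old_log_probs)
infer_log_probs = torch.tensor([[-1.1, -1.0, -0.6, -9.0]])   # generator side (infer_log_probs)
response_mask   = torch.ones_like(old_log_probs, dtype=torch.bool)

# Token-level ratio, as in _apply_token_rejection: exp(old - infer).
ratio = (old_log_probs - infer_log_probs).exp()
token_keep = response_mask & (ratio >= 0.0) & (ratio <= 2.0)  # assumed min/max thresholds
print(ratio)       # ~[[1.105, 0.368, 1.105, 403.43]] -> last token exceeds the max threshold
print(token_keep)  # [[True, True, True, False]]

# Sequence-level aggregation, as in _apply_sequence_rejection:
#   "sequence"  -> exp(sum of masked log-ratios)
#   "geometric" -> exp(mean of masked log-ratios), i.e. geometric mean of token ratios
log_ratio      = old_log_probs - infer_log_probs
seq_ratio_sum  = (log_ratio * response_mask).sum(-1).exp()                            # ~181.3
seq_ratio_geom = ((log_ratio * response_mask).sum(-1) / response_mask.sum(-1)).exp()  # ~3.67

# Wiring, as in the agentic loss_func diff: the handler's keep mask shrinks the mask
# that pg_loss is aggregated over, so rejected tokens no longer contribute to the loss.
pg_loss = torch.tensor([[0.2, -0.1, 0.3, 5.0]])               # toy per-token policy loss
final_response_mask = response_mask & token_keep
masked_pg_loss = (pg_loss * final_response_mask).sum() / final_response_mask.sum()
print(masked_pg_loss)  # mean over the 3 kept tokens only
```

In this toy case the single outlier token dominates the summed sequence ratio (~181), while the geometric variant stays on a per-token scale (~3.7); that is the practical difference between the two `enable_seq_reject` aggregation modes implemented in the handler above.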