From 22b27853f334ac6790bf64374f643757cd73933c Mon Sep 17 00:00:00 2001 From: millioniron Date: Wed, 3 Dec 2025 14:08:18 +0800 Subject: [PATCH 1/7] new file: docs_roll/docs/User Guides/Configuration/infer_correction.md new file: examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml new file: examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml new file: examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh new file: examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh modified: roll/configs/base_config.py modified: roll/configs/generating_args.py modified: roll/distributed/scheduler/generate_scheduler.py modified: roll/pipeline/agentic/env_manager/step_env_manager.py modified: roll/pipeline/agentic/env_manager/traj_env_manager.py modified: roll/pipeline/base_worker.py modified: roll/pipeline/rlvr/actor_pg_worker.py modified: roll/pipeline/rlvr/actor_worker.py --- .../Configuration/infer_correction.md | 142 ++++++++++ .../agentic_webshop_infer_correction.yaml | 183 ++++++++++++ .../rlvr_infer_correction_config.yaml | 265 ++++++++++++++++++ .../run_agentic_pipeline_webshop.sh | 10 + .../run_rlvr_pipeline.sh | 5 + roll/configs/base_config.py | 58 +++- roll/configs/generating_args.py | 4 + .../scheduler/generate_scheduler.py | 1 + .../agentic/env_manager/step_env_manager.py | 4 +- .../agentic/env_manager/traj_env_manager.py | 11 +- roll/pipeline/base_worker.py | 165 ++++++++++- roll/pipeline/rlvr/actor_pg_worker.py | 10 +- roll/pipeline/rlvr/actor_worker.py | 23 +- 13 files changed, 857 insertions(+), 24 deletions(-) create mode 100644 docs_roll/docs/User Guides/Configuration/infer_correction.md create mode 100644 examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml create mode 100644 examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml create mode 100644 examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh create mode 100644 examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh diff --git a/docs_roll/docs/User Guides/Configuration/infer_correction.md b/docs_roll/docs/User Guides/Configuration/infer_correction.md new file mode 100644 index 000000000..ecdce6325 --- /dev/null +++ b/docs_roll/docs/User Guides/Configuration/infer_correction.md @@ -0,0 +1,142 @@ +# 训推差异修复 + + +## 简介 + +训推差异是由于RL训练过程中,训练器和生成器之间由于后端不同(vLLM vs SGLang vs FSDP vs Megatron),精度不同(FP8 vs FP16 vs BF16 vs FP32),形成了一种类似off-policy gap,会导致训练不稳定和策略崩溃。 + + +## 实现原理 + + +修复训推差异大致可分为两种方法(1)对训练器和生成器进行策略修正(2)使用infer_log_probs直接代替old_log_probs(trainer)进行PPO ratio计算。第二种方案比较直接,我们着重说明第一种方法。 + +### 对训练器和生成器进行策略修正 + +### IS权重矫正 +通过对训练器(old_log_probs)和生成器(infer_log_prob)之间进行重要性采样矫正,弥合训推差异。与off-policy算法类似,IS权重矫正可以区分token级别和sequence级别,只能选择一个。 + +### MASK过滤掉不符合条件的样本 +与IS权重修正不同的是,此方法对于超过阈值的样本直接进行mask遮掩,过滤掉不符合的样本。涉及的方法有(1)token级别:过滤掉不符合条件的token(2)灾难性token:过滤掉出现灾难性严重偏差的token的句子样本(3)sequence级别:对sequence进行IS计算,过滤掉不符合的句子样本(4)sequence级别,使用几何平均来计算IS权重,但指标也更为敏感 + + +## 关键参数配置 + +生成器是否返回infer_log_probs + +GeneratingArguments: + +```yaml +actor_infer: + model_args: + disable_gradient_checkpointing: true + dtype: fp16 + generating_args: + max_new_tokens: ${response_length} + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: ${num_return_sequences_in_group} + logprobs: 1 +``` +当logprobs:大于0时,返回infer_log_probs + + +------- + +参数的配置在PPOConfig和中,关键配置参数如下: + +```yaml +infer_correction: true + +infer_is_mode: token # 可选 token sequence +infer_is_threshold_min: 0.0 +infer_is_threshold_max: 2.0 # 1.5~5.0 + +enable_token_reject: true 
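# token 级拒绝(示意,使用上文的建议阈值):启用 enable_token_reject 后,ratio 落在下面 [min, max] 区间之外的 token 会被直接 mask 掉;阈值需按任务调整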
+infer_token_mask_threshold_min: 0.0 +infer_token_mask_threshold_max: 2.0 # 2~10 + +enable_catastrophic_reject: true +infer_catastrophic_threshold: 1e-4 + +enable_seq_reject: sequence 可选None sequence geometric +infer_seq_mask_threshold_min: 0.1 +infer_seq_mask_threshold_max: 10 +``` + + +### infer_correction +- **含义**:控制是否启用训推差异修复机制。若启用,系统将使用 `infer_log_probs` 对策略梯度进行矫正。 +- **默认值**:`false` + +### infer_is_mode +- **含义**:指定重要性采样(IS)权重的计算粒度。 +- **可选值**: + - `"token"`:每个 token 独立计算 IS 权重 + - `"sequence"`:基于整个 response 序列的 log-ratio 求和后广播至所有 token + - `"none"`(或未设置):不应用 IS 加权,权重恒为 1 +- **默认值**:若未设置,默认为 `"token"` +- **注意**:不可同时使用多种模式,仅能选择其一。 + +### infer_is_threshold_min +- **含义**:IS 权重的下限阈值,用于裁剪过小的权重以控制方差。 +- **默认值**:`0.0` +- **建议**:通常保留为 `0.0`,以保持无偏性下界 + +### infer_is_threshold_max +- **含义**:IS 权重的上限阈值,防止极端大的权重主导梯度。 +- **默认值**:`2.0` +- **建议**:`"token"`级别推荐为 `1.5 ~ 5.0` `"sequence"`级别推荐为2.0 - 10.0 + +### enable_token_reject +- **含义**:是否启用 token 级别的样本拒绝机制。 +- **默认值**:`false` +- **作用**:结合 `infer_token_mask_threshold_min/max`,mask 掉 IS ratio 超出合法区间的 token。 + +### infer_token_mask_threshold_min +- **含义**:token 级 IS ratio(`old_log_probs / infer_log_probs` 的指数)的下限。 +- **默认值**:`0.0` +- **典型值**:`0.0`通常可设为1/max + +### infer_token_mask_threshold_max +- **含义**:遮掩token 级 IS ratio 的上限。 +- **默认值**:`2.0` +- **典型范围**:`1.5 ~ 5.0` + +### enable_catastrophic_reject +- **含义**:是否启用“灾难性偏差”检测并拒绝整句样本。 +- **默认值**:`false` +- **触发条件**:只要序列中存在任一 token 满足 `ratio < infer_catastrophic_threshold`,则整句被 mask。 + +### infer_catastrophic_threshold +- **含义**:灾难性拒绝的 ratio 下限阈值。 +- **默认值**:`1e-4` +- **解释**:当 `infer_log_probs` 远大于 `old_log_probs`(即生成器过于“自信”),导致 `ratio = exp(old - infer)` 极小 + +### enable_seq_reject +- **含义**:是否启用序列级别的拒绝机制,以及使用何种聚合方式。 +- **可选值**: + - `null` / `false`:禁用 + - `"sequence"`:使用 log-ratio **求和** 计算序列 IS ratio + - `"geometric"`:使用 log-ratio **平均**(等价于几何平均概率)计算序列 IS ratio +- **默认值**:`null` + +### infer_seq_mask_threshold_min +- **含义**:遮掩序列级 IS ratio 的下限。 +- **默认值**:`0.1` +- **典型值**:通常可设为1/max,当使用`"geometric"`时,最好强制设为1/max + + +### infer_seq_mask_threshold_max +- **含义**:遮掩序列级 IS ratio 的上限。 +- **默认值**:`10.0` +- **典型范围**:当使用`"sequence"`时,推荐`2.0 ~ 10.0`,但随着长度增加可适当放宽。当使用`"geometric"`时,推荐设置为1.0001 - 1.001 + + + +## 使用建议 + +1. 通常情况下,old_log_prob << infer_log_porb, 阈值的下限就比较重要了。并不建议使用sequence级别的IS或MASK + diff --git a/examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml b/examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml new file mode 100644 index 000000000..78fefee47 --- /dev/null +++ b/examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml @@ -0,0 +1,183 @@ +defaults: + - ../config/traj_envs@_here_ + - ../config/deepspeed_zero@_here_ + - ../config/deepspeed_zero2@_here_ + - ../config/deepspeed_zero3@_here_ + - ../config/deepspeed_zero3_cpuoffload@_here_ + +hydra: + run: + dir: . 
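  # keep the working directory unchanged and skip writing the .hydra config snapshot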
+ output_subdir: null + +exp_name: "agentic_pipeline_webshop_infer_correction" +seed: 42 +logging_dir: ./output/logs +output_dir: ./output +render_save_dir: ./output/render +system_envs: + USE_MODELSCOPE: '1' + +#track_with: wandb +#tracker_kwargs: +# api_key: +# project: roll-agentic +# name: ${exp_name}_webshop +# notes: "agentic_pipeline" +# tags: +# - agentic +# - roll +# - baseline + +#track_with: swanlab +#tracker_kwargs: +# login_kwargs: +# api_key: your_api_key +# project: roll-agentic +# logdir: debug +# experiment_name: ${exp_name} +# tags: +# - roll +# - agentic +# - debug + +track_with: tensorboard +tracker_kwargs: + log_dir: /data/oss_bucket_0/yali/llm/tensorboard/roll_exp/agentic_webshop + +num_gpus_per_node: 8 + +max_steps: 1024 +save_steps: 10000 +logging_steps: 1 +eval_steps: 10 +resume_from_checkpoint: false + +rollout_batch_size: 64 +val_batch_size: 64 +sequence_length: 8192 + +reward_clip: 20 +advantage_clip: 0.2 # 0.1-0.3 +ppo_epochs: 1 +adv_estimator: "grpo" +#pg_clip: 0.1 +max_grad_norm: 1.0 +#dual_clip_loss: True +init_kl_coef: 0.0 +whiten_advantages: true +entropy_loss_coef: 0 + +pretrain: Qwen/Qwen2.5-7B-Instruct +reward_pretrain: Qwen/Qwen2.5-7B-Instruct + +# infer correction + +infer_correction: true + +infer_is_mode: token #token sequence +infer_is_threshold_min: 0.0 +infer_is_threshold_max: 2.0 # 1.5~5.0 + +enable_token_reject: false +infer_token_mask_threshold_min: 0.0 +infer_token_mask_threshold_max: 2.0 # 2~10 + +enable_catastrophic_reject: false +infer_catastrophic_threshold: 1e-4 + + +enable_seq_reject: None # None sequence geometric +infer_seq_mask_threshold_min: 0.999 +infer_seq_mask_threshold_max: 1.001 + +# enable_seq_reject: geometric +# infer_seq_mask_threshold_min: 0.999 +# infer_seq_mask_threshold_max: 1.001 + + +actor_train: + model_args: + attn_implementation: fa2 + disable_gradient_checkpointing: false + dtype: bf16 + model_type: ~ + training_args: + learning_rate: 1.0e-6 + weight_decay: 0 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 8 + warmup_steps: 10 + data_args: + template: qwen2_5 + strategy_args: + strategy_name: megatron_train + strategy_config: + tensor_model_parallel_size: 1 + context_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + use_distributed_optimizer: true + recompute_granularity: full + max_grad_norm: ${max_grad_norm} + device_mapping: list(range(0,8)) + infer_batch_size: 1 + +actor_infer: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + generating_args: + max_new_tokens: 1024 # single-turn response length + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: 1 + logprobs: 1 + data_args: + template: qwen2_5 + strategy_args: + strategy_name: sglang + strategy_config: + mem_fraction_static: 0.85 + load_format: auto + device_mapping: list(range(0,8)) + infer_batch_size: 1 + +reference: + model_args: + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: ~ + data_args: + template: qwen2_5 + strategy_args: + strategy_name: hf_infer + strategy_config: ~ + device_mapping: list(range(0,8)) + infer_batch_size: 1 + +reward_normalization: + grouping: traj_group_id # 可以tags(env_type)/traj_group_id(group)/batch(rollout_batch)... 
group_by计算reward/adv + method: mean_std # asym_clip / identity / mean_std + +train_env_manager: + format_penalty: -0.05 + num_env_groups: 8 + group_size: 8 + max_env_num_per_worker: 1 # The max_env_num_per_worker must be set to 1 to avoid conflicts with the webshop simple server. + tags: [WebShopEnv] + num_groups_partition: [8] # If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation + +val_env_manager: + num_env_groups: 64 + group_size: 1 # should be set to 1 because val temperature is set to 0 and same prompt leads to same output + max_env_num_per_worker: 1 # The max_env_num_per_worker must be set to 1 to avoid conflicts with the webshop simple server. + tags: [WebShopEnv] + num_groups_partition: [64] # TODO: If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation + +custom_envs: + WebShopEnv: + ${custom_env.WebShopEnv} \ No newline at end of file diff --git a/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml new file mode 100644 index 000000000..e2cbac8a7 --- /dev/null +++ b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml @@ -0,0 +1,265 @@ +hydra: + run: + dir: . + output_subdir: null + +exp_name: "qwen2.5-7B-rlvr-infer-correction-config" +seed: 42 +logging_dir: ./output/logs +output_dir: ./output +system_envs: + USE_MODELSCOPE: '1' + + +checkpoint_config: + type: file_system + output_dir: /data/cpfs_0/rl_examples/models/${exp_name} + +track_with: tensorboard +tracker_kwargs: + log_dir: ./rl_examples/llm/tensorboard/roll_exp/rlvr + +num_gpus_per_node: 8 + +max_steps: 500 +save_steps: 100 +logging_steps: 1 +eval_steps: 10 +resume_from_checkpoint: false + + +rollout_batch_size: 64 # prompt +prompt_length: 2048 +response_length: 4096 + +num_return_sequences_in_group: 8 +ppo_epochs: 1 +adv_estimator: "reinforce" + +# clip +value_clip: 0.5 +reward_clip: 10 +advantage_clip: 2.0 +dual_clip_loss: true + +# normalize +norm_mean_type: ~ +norm_std_type: ~ + +# data mask +max_len_mask: true +difficulty_mask: true +difficulty_low_threshold: 0.1 +difficulty_high_threshold: 0.95 +error_max_len_clip: false + +# data weight +difficulty_loss_weight: false +length_loss_weight: false + +# reward +add_token_level_kl: false + +# advantage +whiten_advantages: true + +# dynamic sampling scheduler +# use_additional_prompts: true +# max_running_requests: 256 +# is_num_return_sequences_expand: false + +pretrain: Qwen/Qwen2.5-7B-Instruct +reward_pretrain: Qwen/Qwen2.5-7B-Instruct + + +# infer correction +infer_correction: true + +infer_is_mode: token +infer_is_threshold_min: 0.0 +infer_is_threshold_max: 2.0 # 1.5~5.0 + +enable_token_reject: false +infer_token_rs_threshold_min: 0.0 +infer_token_rs_threshold_max: 2.0 # 2~10 + +enable_catastrophic_reject: false +infer_catastrophic_threshold: 1e-4 + +enable_seq_reject: None + +# enable_seq_reject: sequence +# infer_seq_rs_threshold_min: 0.1 +# infer_seq_rs_threshold_max: 10 + +# enable_seq_reject: geometric +# infer_seq_rs_threshold_min: 0.999 +# infer_seq_rs_threshold_max: 1.001 + +validation: + data_args: + template: qwen2.5 + file_name: + - data/math_benchmarks.jsonl + generating_args: + max_new_tokens: ${response_length} + top_p: 0.6 + top_k: 50 + num_beams: 1 + temperature: 0.6 + num_return_sequences: 1 + + +actor_train: + model_args: + disable_gradient_checkpointing: false + dtype: bf16 + model_type: 
~ + training_args: + learning_rate: 1.0e-6 + weight_decay: 0 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 32 + warmup_steps: 20 + num_train_epochs: 50 + data_args: + template: qwen2.5 + file_name: + - data/code_KodCode_data.jsonl + - data/llm_judge_Multi-subject-RLVR_deal_new.jsonl + - data/math_deepmath_deal.jsonl + - data/general_ifeval_train_deal.jsonl + - data/general_CrossThink-QA_deal.jsonl + domain_interleave_probs: + math_rule: 0.4 + code_sandbox: 0.3 + llm_judge: 0.1 + crossthinkqa: 0.1 + ifeval: 0.1 + dataset_dir: data + messages: messages + interleave_probs: "1.0" + preprocessing_num_workers: 16 + strategy_args: + strategy_name: megatron_train + strategy_config: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + use_distributed_optimizer: true + recompute_granularity: full + device_mapping: list(range(0,16)) + infer_batch_size: 4 + +actor_infer: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + generating_args: + max_new_tokens: ${response_length} + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: ${num_return_sequences_in_group} + logprobs: 1 + data_args: + template: qwen2.5 + strategy_args: + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.8 + block_size: 16 + max_model_len: 8000 + device_mapping: list(range(0,12)) + infer_batch_size: 1 + +reference: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + model_type: ~ + data_args: + template: qwen2.5 + strategy_args: + strategy_name: megatron_infer + strategy_config: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + device_mapping: list(range(0,16)) + infer_batch_size: 4 + +rewards: + crossthinkqa: + worker_cls: roll.pipeline.rlvr.rewards.crossthinkqa_rule_reward_worker.CrossThinkQARuleRewardWorker + reward_type: soft + response_length_penalty_coef: 0.0 + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + tag_included: [crossthinkqa] + world_size: 8 + infer_batch_size: 4 + ifeval: + worker_cls: roll.pipeline.rlvr.rewards.ifeval_rule_reward_worker.GeneralRuleRewardWorker + reward_type: soft + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + tag_included: [ifeval] + world_size: 8 + infer_batch_size: 4 + math_rule: + worker_cls: roll.pipeline.rlvr.rewards.math_rule_reward_worker.MathRuleRewardWorker + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + tag_included: [deepmath_103k, aime] + world_size: 8 + infer_batch_size: 1 + code_sandbox: + use_local: true + worker_cls: roll.pipeline.rlvr.rewards.code_sandbox_reward_worker.CodeSandboxRewardWorker + tag_included: [KodCode] + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + world_size: 8 + infer_batch_size: 1 + llm_judge: + # NOTE: llm as judge 也需要gpu, 不能和actor infer共享gpu + worker_cls: roll.pipeline.rlvr.rewards.llm_judge_reward_worker.LLMJudgeRewardWorker + judge_prompt: Qwen2.5-7B-Instruct-RLVR-prompt + judge_model_type: inference + tag_included: [RLVR] + model_args: + model_name_or_path: virtuoussy/Qwen2.5-7B-Instruct-RLVR + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: trl + generating_args: + max_new_tokens: 100 + top_p: 0.8 + top_k: 50 + num_beams: 1 + temperature: 0.8 + num_return_sequences: 1 + data_args: + template: qwen2.5 + strategy_args: + # strategy_name: hf_infer + # 
strategy_config: null + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.8 + block_size: 16 + max_model_len: 8000 + load_format: auto + device_mapping: list(range(12,16)) + infer_batch_size: 4 \ No newline at end of file diff --git a/examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh b/examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh new file mode 100644 index 000000000..ff1a59402 --- /dev/null +++ b/examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh @@ -0,0 +1,10 @@ +#!/bin/bash +# Run `git submodule update --init --recursive` to init submodules before run this script. +set +x + + +pip install -r third_party/webshop-minimal/requirements.txt --trusted-host mirrors.aliyun.com --index-url https://mirrors.aliyun.com/pypi/simple/ +python -m spacy download en_core_web_sm + +CONFIG_PATH=$(basename $(dirname $0)) +python examples/start_agentic_pipeline.py --config_path $CONFIG_PATH --config_name agentic_webshop_infer_correction diff --git a/examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh b/examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh new file mode 100644 index 000000000..6874f4f5e --- /dev/null +++ b/examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh @@ -0,0 +1,5 @@ +#!/bin/bash +set +x + +CONFIG_PATH=$(basename $(dirname $0)) +python examples/start_rlvr_pipeline.py --config_path $CONFIG_PATH --config_name rlvr_infer_correction_config \ No newline at end of file diff --git a/roll/configs/base_config.py b/roll/configs/base_config.py index 350f69cc5..617f672be 100644 --- a/roll/configs/base_config.py +++ b/roll/configs/base_config.py @@ -382,7 +382,63 @@ class PPOConfig(BaseConfig): field(default="seq-mean-token-mean", metadata={"help": "Loss aggregation mode"}) ) dual_clip_loss: bool = field(default=False, metadata={"help": "Use dual clip loss"}) - + + # trainer&rollout mismatch + infer_correction: bool = field( + default=False, + metadata={"help": "Whether to apply importance sampling correction during inference."} + ) + infer_is_mode: Literal["token", "sequence", "none"] = field( + default="token", + metadata={"help": "IS weighting mode: 'token' (per-token ratio), 'sequence' (per-sequence ratio), 'none' (no IS weighting)."} + ) + # Clipping thresholds (used in IS weighting) + infer_is_threshold_min: float = field( + default=0.0, + metadata={"help": "Minimum threshold for IS weight clipping. 
Recommended 0.0 for unbiased estimation."} + ) + infer_is_threshold_max: float = field( + default=2.0, + metadata={"help": "Maximum threshold for IS weight clipping."} + ) + # Token-level rejection + enable_token_reject: bool = field( + default=False, + metadata={"help": "Enable token-level rejection based on IS ratio thresholds."} + ) + infer_token_mask_threshold_min: float = field( + default=0.0, + metadata={"help": "Minimum IS ratio threshold for token rejection."} + ) + infer_token_mask_threshold_max: float = field( + default=2.0, + metadata={"help": "Maximum IS ratio threshold for token rejection."} + ) + # Catastrophic rejection (reject entire sequence if any token ratio is too small) + enable_catastrophic_reject: bool = field( + default=False, + metadata={"help": "Enable catastrophic rejection: reject entire sequence if any valid token has IS ratio below threshold."} + ) + infer_catastrophic_threshold: float = field( + default=1e-4, + metadata={"help": "Threshold below which a token triggers catastrophic rejection of its sequence."} + ) + # Sequence-level rejection + enable_seq_reject: Optional[Literal["sequence", "geometric",'None']] = field( + default=None, + metadata={"help": "Enable sequence-level rejection: 'sequence' uses sum of log-ratios, 'geometric' uses mean. None disables."} + ) + infer_seq_mask_threshold_min: float = field( + default=0.1, + metadata={"help": "Minimum IS ratio threshold for sequence rejection."} + ) + infer_seq_mask_threshold_max: float = field( + default=10.0, + metadata={"help": "Maximum IS ratio threshold for sequence rejection (typically larger than token-level)."} + ) + + + def __post_init__(self): super().__post_init__() diff --git a/roll/configs/generating_args.py b/roll/configs/generating_args.py index 059aff4a9..cf09e2f21 100644 --- a/roll/configs/generating_args.py +++ b/roll/configs/generating_args.py @@ -58,6 +58,10 @@ class GeneratingArguments: default=None, metadata={"help": "Whether to include the stop strings in output text."}, ) + logprobs: Optional[int] = field( + default=None, + metadata={"help": "Whether return infer log-prob."}, + ) def to_dict(self) -> Dict[str, Any]: args = asdict(self) diff --git a/roll/distributed/scheduler/generate_scheduler.py b/roll/distributed/scheduler/generate_scheduler.py index c49b502c1..ab2630304 100644 --- a/roll/distributed/scheduler/generate_scheduler.py +++ b/roll/distributed/scheduler/generate_scheduler.py @@ -856,6 +856,7 @@ async def generate_one_request(self, data: DataProto): eos_token_id=eos_token_id, pad_token_id=pad_token_id, pad_to_seq_len=data.meta_info.get("pad_to_seq_len", True), + output_logprobs=response_data.meta_info.get("output_logprobs",None), ) request_repeat = data.repeat(repeat_times=len(output_tokens)) output.non_tensor_batch = request_repeat.non_tensor_batch diff --git a/roll/pipeline/agentic/env_manager/step_env_manager.py b/roll/pipeline/agentic/env_manager/step_env_manager.py index 737910394..da2741982 100644 --- a/roll/pipeline/agentic/env_manager/step_env_manager.py +++ b/roll/pipeline/agentic/env_manager/step_env_manager.py @@ -85,6 +85,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): prompt_mask = torch.tensor(prompt_masks, dtype=torch.bool).unsqueeze(0) score_tensor = torch.tensor([0] * len(token_ids), dtype=torch.float).unsqueeze(0) score_tensor[0][-1] = history['reward'] + infer_logprobs = history["infer_logprobs"].flatten().unsqueeze(0) position_ids = attention_mask.cumsum(dim=-1) input_ids = pad_to_length(input_ids, 
length=self.pipeline_config.sequence_length, pad_value=self.tokenizer.pad_token_id) @@ -93,7 +94,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): response_mask = pad_to_length(response_mask, length=self.pipeline_config.sequence_length, pad_value=0) prompt_mask = pad_to_length(prompt_mask, length=self.pipeline_config.sequence_length, pad_value=0) score_tensor = pad_to_length(score_tensor, length=self.pipeline_config.sequence_length, pad_value=0) - + infer_logprobs = pad_to_length(infer_logprobs, length=self.pipeline_config.sequence_length, pad_value=0) samples.append(DataProto( batch=TensorDict( { @@ -103,6 +104,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): "response_mask": response_mask, "prompt_mask": prompt_mask, "scores": score_tensor, + "infer_logprobs": infer_logprobs, }, batch_size=input_ids.shape[0]), non_tensor_batch={ diff --git a/roll/pipeline/agentic/env_manager/traj_env_manager.py b/roll/pipeline/agentic/env_manager/traj_env_manager.py index 050594636..bc65ff11a 100644 --- a/roll/pipeline/agentic/env_manager/traj_env_manager.py +++ b/roll/pipeline/agentic/env_manager/traj_env_manager.py @@ -11,7 +11,7 @@ from omegaconf import DictConfig from tensordict import TensorDict from transformers import PreTrainedTokenizer - +import json from roll.pipeline.agentic.llm_proxy import create_llm_proxy, BaseLLMProxy from roll.pipeline.agentic.env_manager.base_env_manager import RolloutCache, BaseEnvManager from roll.utils.env_action_limiter import get_global_limiter @@ -179,6 +179,7 @@ def step(self, llm_output: DataProto): self.rollout_cache.truncated = True self.rollout_cache.history[-1]['reward'] = reward self.rollout_cache.history[-1]['llm_response'] = responses[0] + self.rollout_cache.history[-1]['infer_logprobs'] = llm_output.batch['infer_logprobs'] if info is not None: self.rollout_cache.history[-1].update(info) @@ -293,13 +294,16 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): token_ids = [] prompt_masks = [] response_masks = [] + infer_logprobs = [] for items in self.rollout_cache.history: token_ids.extend(items["prompt_ids"]) token_ids.extend(items["response_ids"]) prompt_masks.extend([1] * len(items["prompt_ids"]) + [0] * len(items["response_ids"])) response_masks.extend([0] * len(items["prompt_ids"]) + [1] * len(items["response_ids"])) - + infer_logprobs.extend(items["infer_logprobs"].flatten().tolist()) + input_ids =torch.tensor(token_ids, dtype=torch.long).unsqueeze(0) + infer_logprobs = torch.tensor(infer_logprobs, dtype=torch.float).unsqueeze(0) attention_mask = torch.tensor([1] * len(token_ids), dtype=torch.long).unsqueeze(0) response_mask = torch.tensor(response_masks, dtype=torch.bool).unsqueeze(0) @@ -316,6 +320,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): "input_ids": input_ids, "attention_mask": attention_mask, "position_ids": position_ids, + "infer_logprobs": infer_logprobs, }, batch_size=input_ids.shape[0]) @@ -323,6 +328,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): # TODO: move pad to pipeline input_ids = pad_to_length(input_ids, length=self.pipeline_config.sequence_length, pad_value=self.tokenizer.pad_token_id) + infer_logprobs = pad_to_length(infer_logprobs, length=self.pipeline_config.sequence_length, pad_value=0) attention_mask = pad_to_length(attention_mask, length=self.pipeline_config.sequence_length, pad_value=0) position_ids = pad_to_length(position_ids, length=self.pipeline_config.sequence_length, pad_value=0) response_mask = pad_to_length(response_mask, 
length=self.pipeline_config.sequence_length, pad_value=0) @@ -336,6 +342,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): "response_mask": response_mask, "prompt_mask": prompt_mask, "scores": score_tensor, + "infer_logprobs": infer_logprobs, }) lm_input.non_tensor_batch.update({ "env_ids": np.array([self.rollout_cache.env_id], dtype=object), diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py index 0987cfefa..f6e594fe7 100644 --- a/roll/pipeline/base_worker.py +++ b/roll/pipeline/base_worker.py @@ -2,7 +2,7 @@ import threading import time from typing import Union, Optional, Dict - +import numpy import ray import torch from codetiming import Timer @@ -25,6 +25,7 @@ postprocess_generate, GenerateRequestType, agg_loss, + masked_sum ) from roll.utils.offload_states import OffloadStateType from roll.utils.dynamic_batching import make_mini_batch_iter_for_dynamic_batching @@ -264,9 +265,11 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ response_mask = data.batch["response_mask"][:, 1:].long() + final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] old_log_probs = data.batch["old_log_probs"] advantages = data.batch["advantages"] + infer_log_probs = data.batch.get("infer_logprobs", None) log_probs = self.strategy.op_compute_log_probs( logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"] @@ -282,12 +285,18 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-pg_loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) pg_loss = torch.where(advantages < 0, dual_clip_loss, pg_loss) + + if infer_log_probs is not None and self.pipeline_config.infer_correction: + pg_loss, infer_response_mask, infer_stats=self.infer_correction( + old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, + response_mask=response_mask,pg_loss=pg_loss) + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() + + pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=final_response_mask, - pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) - kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=response_mask, - kl_penalty="k3") - kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) approxkl = compute_approx_kl( log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="mse" @@ -329,9 +338,153 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): "actor/policykl": agg_loss(loss_mat=policykl, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(), } + pg_metrics.update(infer_stats) return total_loss, pg_metrics - + + + def infer_correction( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor, + pg_loss: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, dict]: + """ + 处理 importance sampling ratio,支持 IS 裁剪与多种 reject 策略。 + 返回更新后的 pg_loss、mask 和详细统计信息。 + """ + # Step 0: Shape alignment 
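+        # Note: infer_logprobs is padded to the full (unshifted) sequence length by the env
+        # managers, while old_log_probs and response_mask here are already shifted by one
+        # ([:, 1:]); when the widths differ by exactly one column, drop the leading column of
+        # infer_log_probs so all three tensors line up.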
+ if infer_log_probs.shape[1] == old_log_probs.shape[1]+1: + infer_log_probs = infer_log_probs[:, 1:] # align with response_mask[:, 1:] + assert old_log_probs.shape == infer_log_probs.shape == response_mask.shape, \ + f"Shape mismatch: {old_log_probs.shape}, {infer_log_probs.shape}, {response_mask.shape}" + # Step 1: Compute log-ratio and ratio + log_ratio = old_log_probs - infer_log_probs # [B, T] + ratio = torch.exp(log_ratio) # [B, T] + # Step 2: Apply IS weighting strategy (optional) + if self.pipeline_config.infer_is_mode == "token": + raw_is_weight = ratio + elif self.pipeline_config.infer_is_mode == "sequence": + log_ratio_sum = masked_sum(log_ratio, response_mask, dim=-1).unsqueeze(-1) # [B, 1] + raw_is_weight = torch.exp(log_ratio_sum).expand_as(ratio) # [B, T] + elif self.pipeline_config.infer_is_mode in (None, "none", ""): + raw_is_weight = torch.ones_like(ratio) + else: + raw_is_weight = torch.ones_like(ratio) + # Clamp to get final is_weight (used for loss) + is_weight = raw_is_weight.clamp( + min=self.pipeline_config.infer_is_threshold_min, + max=self.pipeline_config.infer_is_threshold_max + ).detach() + # Step 3: Build rejection mask + original_valid = response_mask > 0.5 # [B, T], bool + keep_mask = original_valid.clone() + # (a) Token-level ratio reject + if getattr(self.pipeline_config, 'enable_token_reject', False): + ratio_too_high = ratio > self.pipeline_config.infer_token_mask_threshold_max + ratio_too_low = ratio < self.pipeline_config.infer_token_mask_threshold_min + token_reject = ratio_too_high | ratio_too_low + keep_mask = keep_mask & (~token_reject) + # (b) Catastrophic reject + if getattr(self.pipeline_config, 'enable_catastrophic_reject', False): + catastrophic = (ratio < self.pipeline_config.infer_catastrophic_threshold) & original_valid + has_catastrophic = catastrophic.any(dim=-1, keepdim=True) + keep_mask = keep_mask & (~has_catastrophic) + # (c) Sequence-level reject + if getattr(self.pipeline_config, 'enable_seq_reject', False): + if self.pipeline_config.enable_seq_reject=="sequence": + log_ratio_sum = masked_sum(log_ratio, response_mask, dim=-1) # [B] + seq_ratio = torch.exp(log_ratio_sum) # [B] + seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max + seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min + seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) + keep_mask = keep_mask & (~seq_reject) + elif self.pipeline_config.enable_seq_reject=="geometric": + log_ratio_mean = masked_mean(log_ratio, response_mask, dim=-1) # [B] + seq_ratio = torch.exp(log_ratio_mean) # [B] + seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max + seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min + seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) + keep_mask = keep_mask & (~seq_reject) + # final_mask = keep_mask.float() + final_mask = keep_mask + # Step 4: Reweight policy loss + pg_loss = pg_loss * is_weight + # Step 5: Compute detailed stats over original_valid tokens + # Rejected mask + rejected_mask = original_valid & (~keep_mask) # [B, T] + # Clipped mask: only meaningful if IS weighting is active + if self.pipeline_config.infer_is_mode in ("token", "sequence"): + clipped_low = (raw_is_weight <= self.pipeline_config.infer_is_threshold_min) & original_valid + clipped_high = (raw_is_weight >= self.pipeline_config.infer_is_threshold_max) & original_valid + clipped_mask = clipped_low | clipped_high # [B, T] + else: + clipped_mask = torch.zeros_like(original_valid) # no 
clipping + # Compute fractions + def _compute_frac(mask_tensor): + return agg_loss( + loss_mat=mask_tensor.float(), + loss_mask=response_mask, + loss_agg_mode="token-mean" # force token-wise average + ).detach().item() + clip_frac = _compute_frac(clipped_mask) + reject_frac = _compute_frac(rejected_mask) + clip_and_reject_frac = _compute_frac(clipped_mask & rejected_mask) + clip_or_reject_frac = _compute_frac(clipped_mask | rejected_mask) + # A sequence is rejected if NO token is kept (i.e., all final_mask == 0 for that seq) + seq_has_valid = original_valid.any(dim=-1) # [B], bool: seq has >=1 valid token + seq_completely_rejected = (~keep_mask).all(dim=-1) & seq_has_valid # [B] + total_valid_seqs = seq_has_valid.sum().item() + rejected_seqs = seq_completely_rejected.sum().item() + seq_reject_frac = rejected_seqs / total_valid_seqs if total_valid_seqs > 0 else 0.0 + + ### kl metric + inferkl_orig = compute_approx_kl( + log_probs=infer_log_probs, + log_probs_base=old_log_probs, + action_mask=response_mask, # ← original mask + kl_penalty="kl" + ) + inferkl_final = compute_approx_kl( + log_probs=infer_log_probs, + log_probs_base=old_log_probs, + action_mask=final_mask, # ← after rejection + kl_penalty="kl" + ) + inferkl_orig_agg = agg_loss( + loss_mat=inferkl_orig, + loss_mask=response_mask, + loss_agg_mode=self.pipeline_config.loss_agg_mode + ).detach().item() + inferkl_final_agg = agg_loss( + loss_mat=inferkl_final, + loss_mask=final_mask, + loss_agg_mode=self.pipeline_config.loss_agg_mode + ).detach().item() + valid_raw_is_weight = raw_is_weight[original_valid] # [N_valid_tokens,] + if valid_raw_is_weight.numel() > 0: + raw_is_mean = valid_raw_is_weight.mean().detach().item() + raw_is_std = valid_raw_is_weight.std(unbiased=False).detach().item() + raw_is_min = valid_raw_is_weight.min().detach().item() + raw_is_max = valid_raw_is_weight.max().detach().item() + else: + # fallback if no valid tokens (rare edge case) + raw_is_mean = raw_is_std = raw_is_min = raw_is_max = 0.0 + stats = { + "infer_correction/reject_frac": reject_frac, + "infer_correction/clip_frac": clip_frac, + "infer_correction/clip_and_reject_frac": clip_and_reject_frac, + "infer_correction/clip_or_reject_frac": clip_or_reject_frac, + "infer_correction/seq_reject_frac": seq_reject_frac, + "infer_correction/inferkl_orig": inferkl_orig_agg, + "infer_correction/inferkl_final": inferkl_final_agg, + "infer_correction/raw_is_mean": raw_is_mean, + "infer_correction/raw_is_std": raw_is_std, + "infer_correction/raw_is_min": raw_is_min, + "infer_correction/raw_is_max": raw_is_max, + } + return pg_loss, final_mask, stats @register(dispatch_mode=Dispatch.ONE_TO_ALL) def do_checkpoint(self, global_step): with Timer("do_checkpoint") as total_timer: diff --git a/roll/pipeline/rlvr/actor_pg_worker.py b/roll/pipeline/rlvr/actor_pg_worker.py index 813e582f5..b518acc2e 100644 --- a/roll/pipeline/rlvr/actor_pg_worker.py +++ b/roll/pipeline/rlvr/actor_pg_worker.py @@ -30,6 +30,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ response_mask = data.batch["response_mask"][:, 1:].long() + infer_log_probs = data.batch.get("infer_logprobs", None) final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] @@ -76,7 +77,13 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): pg_loss = self._compute_kimi15_loss(ratio, log_probs, old_log_probs, advantages) else: raise ValueError(f"Unsupported pg_variant: {pg_variant}") - + + if infer_log_probs is not None 
and self.pipeline_config.infer_correction: + loss, infer_response_mask, infer_stats=self.infer_correction( + old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, + response_mask=response_mask,pg_loss=loss) + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() + weighted_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights) original_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, @@ -127,6 +134,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): # 使用缓存的指标 pg_metrics = self._get_pg_metrics(data) + pg_metrics.updata(infer_stats) return total_loss, pg_metrics diff --git a/roll/pipeline/rlvr/actor_worker.py b/roll/pipeline/rlvr/actor_worker.py index fbd7147bd..603ab4d7b 100644 --- a/roll/pipeline/rlvr/actor_worker.py +++ b/roll/pipeline/rlvr/actor_worker.py @@ -19,9 +19,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): ref_log_probs = data.batch["ref_log_probs"] old_log_probs = data.batch["old_log_probs"] - infer_log_probs = data.batch.get("infer_logprobs", old_log_probs) - infer_log_probs = infer_log_probs if len(infer_log_probs) > 0 else old_log_probs - + infer_log_probs = data.batch.get("infer_logprobs", None) + advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( @@ -57,8 +56,6 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.importance_sampling == "token": ratio = (log_probs - old_log_probs).exp() - train_infer_ratio = (log_probs - infer_log_probs).exp() - train_infer_diff = log_probs.exp() - infer_log_probs.exp() elif self.pipeline_config.importance_sampling == "seq": log_ratio = log_probs - old_log_probs masked_log_ratio = masked_mean(log_ratio, final_response_mask, dim=-1) @@ -73,7 +70,13 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) loss = torch.where(advantages < 0, dual_clip_loss, loss) - + + if infer_log_probs is not None and self.pipeline_config.infer_correction: + loss, infer_response_mask, infer_stats=self.infer_correction( + old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, + response_mask=response_mask,pg_loss=loss) + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() + weighted_pg_loss = agg_loss(loss_mat=loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights, loss_scale=loss_scale) @@ -121,11 +124,6 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): total_loss = total_loss + topr_neg_loss * self.pipeline_config.use_topr_neg_loss_coef metrics['actor/topr_neg_loss'] = topr_neg_loss.detach().item() - train_infer_prob_metric = { - "actor/train_infer_ratio_mean": masked_mean(train_infer_ratio, response_mask, dim=-1).mean().detach().item(), - "actor/train_infer_diff_mean": masked_mean(train_infer_diff, response_mask, dim=-1).mean().detach().item(), - } - loss_metric = { "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(), "actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(), @@ -154,9 +152,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): "actor/sample_weights_max": sample_weights.max().detach().item(), **metrics, **loss_metric, - **train_infer_prob_metric } - + pg_metrics.update(infer_stats) return 
total_loss, pg_metrics def compute_sample_weights(self, data: DataProto, response_mask: torch.Tensor): From 77c1e5d99c897e898ea1d9dc08a3b1e347f3679f Mon Sep 17 00:00:00 2001 From: millioniron Date: Sun, 7 Dec 2025 14:43:05 +0800 Subject: [PATCH 2/7] =?UTF-8?q?=E9=87=8D=E6=96=B0=E4=BF=AE=E8=AE=A2?= =?UTF-8?q?=E4=BA=86=E6=95=B4=E4=B8=AA=E7=9A=84=E6=8E=92=E7=89=88=EF=BC=8C?= =?UTF-8?q?=E6=8A=BD=E8=B1=A1=E5=87=BA=E4=BA=86=E4=B8=80=E4=B8=AA=E7=B1=BB?= =?UTF-8?q?=EF=BC=8C=E4=BD=BF=E5=BE=97=E5=8F=AF=E4=BB=A5=E6=9B=B4=E5=8A=A0?= =?UTF-8?q?=E8=87=AA=E7=94=B1=E5=AE=9A=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../rlvr_infer_correction_config.yaml | 12 +- roll/configs/base_config.py | 6 +- roll/configs/generating_args.py | 2 +- roll/pipeline/base_worker.py | 169 +------- roll/pipeline/rlvr/actor_pg_worker.py | 22 +- roll/pipeline/rlvr/actor_worker.py | 22 +- roll/utils/infer_correction.py | 380 ++++++++++++++++++ 7 files changed, 438 insertions(+), 175 deletions(-) create mode 100644 roll/utils/infer_correction.py diff --git a/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml index e2cbac8a7..1bbe3c8e8 100644 --- a/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml +++ b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml @@ -80,8 +80,8 @@ infer_is_threshold_min: 0.0 infer_is_threshold_max: 2.0 # 1.5~5.0 enable_token_reject: false -infer_token_rs_threshold_min: 0.0 -infer_token_rs_threshold_max: 2.0 # 2~10 +infer_token_mask_threshold_min: 0.0 +infer_token_mask_threshold_max: 2.0 # 2~10 enable_catastrophic_reject: false infer_catastrophic_threshold: 1e-4 @@ -89,12 +89,12 @@ infer_catastrophic_threshold: 1e-4 enable_seq_reject: None # enable_seq_reject: sequence -# infer_seq_rs_threshold_min: 0.1 -# infer_seq_rs_threshold_max: 10 +# infer_seq_mask_threshold_min: 0.1 +# infer_seq_mask_threshold_max: 10 # enable_seq_reject: geometric -# infer_seq_rs_threshold_min: 0.999 -# infer_seq_rs_threshold_max: 1.001 +# infer_seq_mask_threshold_min: 0.999 +# infer_seq_mask_threshold_max: 1.001 validation: data_args: diff --git a/roll/configs/base_config.py b/roll/configs/base_config.py index 617f672be..b2ef1ff28 100644 --- a/roll/configs/base_config.py +++ b/roll/configs/base_config.py @@ -388,9 +388,9 @@ class PPOConfig(BaseConfig): default=False, metadata={"help": "Whether to apply importance sampling correction during inference."} ) - infer_is_mode: Literal["token", "sequence", "none"] = field( - default="token", - metadata={"help": "IS weighting mode: 'token' (per-token ratio), 'sequence' (per-sequence ratio), 'none' (no IS weighting)."} + infer_is_mode: Literal["token", "sequence", "None"] = field( + default="None", + metadata={"help": "IS weighting mode: 'token' (per-token ratio), 'sequence' (per-sequence ratio), 'None' (no IS weighting)."} ) # Clipping thresholds (used in IS weighting) infer_is_threshold_min: float = field( diff --git a/roll/configs/generating_args.py b/roll/configs/generating_args.py index cf09e2f21..8a540eb57 100644 --- a/roll/configs/generating_args.py +++ b/roll/configs/generating_args.py @@ -59,7 +59,7 @@ class GeneratingArguments: metadata={"help": "Whether to include the stop strings in output text."}, ) logprobs: Optional[int] = field( - default=None, + default=0, metadata={"help": "Whether return infer log-prob."}, ) diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py index 
f6e594fe7..62854c8b0 100644 --- a/roll/pipeline/base_worker.py +++ b/roll/pipeline/base_worker.py @@ -25,9 +25,10 @@ postprocess_generate, GenerateRequestType, agg_loss, - masked_sum + masked_sum, ) from roll.utils.offload_states import OffloadStateType +from roll.utils.infer_correction import InferCorrectionHandler from roll.utils.dynamic_batching import make_mini_batch_iter_for_dynamic_batching from roll.platforms import current_platform @@ -266,11 +267,11 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): response_mask = data.batch["response_mask"][:, 1:].long() final_response_mask = data.batch.get("final_response_mask", response_mask) + ref_log_probs = data.batch["ref_log_probs"] old_log_probs = data.batch["old_log_probs"] advantages = data.batch["advantages"] infer_log_probs = data.batch.get("infer_logprobs", None) - log_probs = self.strategy.op_compute_log_probs( logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"] ) @@ -285,19 +286,26 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-pg_loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) pg_loss = torch.where(advantages < 0, dual_clip_loss, pg_loss) - + + infer_stats = {} if infer_log_probs is not None and self.pipeline_config.infer_correction: - pg_loss, infer_response_mask, infer_stats=self.infer_correction( - old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, - response_mask=response_mask,pg_loss=pg_loss) + correction_handler = InferCorrectionHandler(self.pipeline_config) + + pg_loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=pg_loss + ) + # 更新最终掩码 final_response_mask = (final_response_mask.bool() & infer_response_mask).long() pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) - kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=final_response_mask, + kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=final_response_mask, + kl_penalty="k3") kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) - approxkl = compute_approx_kl( log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="mse" ) @@ -341,150 +349,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): pg_metrics.update(infer_stats) return total_loss, pg_metrics - - - def infer_correction( - self, - old_log_probs: torch.Tensor, - infer_log_probs: torch.Tensor, - response_mask: torch.Tensor, - pg_loss: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, dict]: - """ - 处理 importance sampling ratio,支持 IS 裁剪与多种 reject 策略。 - 返回更新后的 pg_loss、mask 和详细统计信息。 - """ - # Step 0: Shape alignment - if infer_log_probs.shape[1] == old_log_probs.shape[1]+1: - infer_log_probs = infer_log_probs[:, 1:] # align with response_mask[:, 1:] - assert old_log_probs.shape == infer_log_probs.shape == response_mask.shape, \ - f"Shape mismatch: {old_log_probs.shape}, {infer_log_probs.shape}, {response_mask.shape}" - # Step 1: Compute log-ratio and ratio - log_ratio = old_log_probs - infer_log_probs # [B, T] - ratio = torch.exp(log_ratio) # [B, T] - # Step 2: Apply IS weighting strategy (optional) - if self.pipeline_config.infer_is_mode == "token": - 
raw_is_weight = ratio - elif self.pipeline_config.infer_is_mode == "sequence": - log_ratio_sum = masked_sum(log_ratio, response_mask, dim=-1).unsqueeze(-1) # [B, 1] - raw_is_weight = torch.exp(log_ratio_sum).expand_as(ratio) # [B, T] - elif self.pipeline_config.infer_is_mode in (None, "none", ""): - raw_is_weight = torch.ones_like(ratio) - else: - raw_is_weight = torch.ones_like(ratio) - # Clamp to get final is_weight (used for loss) - is_weight = raw_is_weight.clamp( - min=self.pipeline_config.infer_is_threshold_min, - max=self.pipeline_config.infer_is_threshold_max - ).detach() - # Step 3: Build rejection mask - original_valid = response_mask > 0.5 # [B, T], bool - keep_mask = original_valid.clone() - # (a) Token-level ratio reject - if getattr(self.pipeline_config, 'enable_token_reject', False): - ratio_too_high = ratio > self.pipeline_config.infer_token_mask_threshold_max - ratio_too_low = ratio < self.pipeline_config.infer_token_mask_threshold_min - token_reject = ratio_too_high | ratio_too_low - keep_mask = keep_mask & (~token_reject) - # (b) Catastrophic reject - if getattr(self.pipeline_config, 'enable_catastrophic_reject', False): - catastrophic = (ratio < self.pipeline_config.infer_catastrophic_threshold) & original_valid - has_catastrophic = catastrophic.any(dim=-1, keepdim=True) - keep_mask = keep_mask & (~has_catastrophic) - # (c) Sequence-level reject - if getattr(self.pipeline_config, 'enable_seq_reject', False): - if self.pipeline_config.enable_seq_reject=="sequence": - log_ratio_sum = masked_sum(log_ratio, response_mask, dim=-1) # [B] - seq_ratio = torch.exp(log_ratio_sum) # [B] - seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max - seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min - seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) - keep_mask = keep_mask & (~seq_reject) - elif self.pipeline_config.enable_seq_reject=="geometric": - log_ratio_mean = masked_mean(log_ratio, response_mask, dim=-1) # [B] - seq_ratio = torch.exp(log_ratio_mean) # [B] - seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max - seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min - seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) - keep_mask = keep_mask & (~seq_reject) - # final_mask = keep_mask.float() - final_mask = keep_mask - # Step 4: Reweight policy loss - pg_loss = pg_loss * is_weight - # Step 5: Compute detailed stats over original_valid tokens - # Rejected mask - rejected_mask = original_valid & (~keep_mask) # [B, T] - # Clipped mask: only meaningful if IS weighting is active - if self.pipeline_config.infer_is_mode in ("token", "sequence"): - clipped_low = (raw_is_weight <= self.pipeline_config.infer_is_threshold_min) & original_valid - clipped_high = (raw_is_weight >= self.pipeline_config.infer_is_threshold_max) & original_valid - clipped_mask = clipped_low | clipped_high # [B, T] - else: - clipped_mask = torch.zeros_like(original_valid) # no clipping - # Compute fractions - def _compute_frac(mask_tensor): - return agg_loss( - loss_mat=mask_tensor.float(), - loss_mask=response_mask, - loss_agg_mode="token-mean" # force token-wise average - ).detach().item() - clip_frac = _compute_frac(clipped_mask) - reject_frac = _compute_frac(rejected_mask) - clip_and_reject_frac = _compute_frac(clipped_mask & rejected_mask) - clip_or_reject_frac = _compute_frac(clipped_mask | rejected_mask) - # A sequence is rejected if NO token is kept (i.e., all final_mask == 0 for that seq) - seq_has_valid 
= original_valid.any(dim=-1) # [B], bool: seq has >=1 valid token - seq_completely_rejected = (~keep_mask).all(dim=-1) & seq_has_valid # [B] - total_valid_seqs = seq_has_valid.sum().item() - rejected_seqs = seq_completely_rejected.sum().item() - seq_reject_frac = rejected_seqs / total_valid_seqs if total_valid_seqs > 0 else 0.0 - - ### kl metric - inferkl_orig = compute_approx_kl( - log_probs=infer_log_probs, - log_probs_base=old_log_probs, - action_mask=response_mask, # ← original mask - kl_penalty="kl" - ) - inferkl_final = compute_approx_kl( - log_probs=infer_log_probs, - log_probs_base=old_log_probs, - action_mask=final_mask, # ← after rejection - kl_penalty="kl" - ) - inferkl_orig_agg = agg_loss( - loss_mat=inferkl_orig, - loss_mask=response_mask, - loss_agg_mode=self.pipeline_config.loss_agg_mode - ).detach().item() - inferkl_final_agg = agg_loss( - loss_mat=inferkl_final, - loss_mask=final_mask, - loss_agg_mode=self.pipeline_config.loss_agg_mode - ).detach().item() - valid_raw_is_weight = raw_is_weight[original_valid] # [N_valid_tokens,] - if valid_raw_is_weight.numel() > 0: - raw_is_mean = valid_raw_is_weight.mean().detach().item() - raw_is_std = valid_raw_is_weight.std(unbiased=False).detach().item() - raw_is_min = valid_raw_is_weight.min().detach().item() - raw_is_max = valid_raw_is_weight.max().detach().item() - else: - # fallback if no valid tokens (rare edge case) - raw_is_mean = raw_is_std = raw_is_min = raw_is_max = 0.0 - stats = { - "infer_correction/reject_frac": reject_frac, - "infer_correction/clip_frac": clip_frac, - "infer_correction/clip_and_reject_frac": clip_and_reject_frac, - "infer_correction/clip_or_reject_frac": clip_or_reject_frac, - "infer_correction/seq_reject_frac": seq_reject_frac, - "infer_correction/inferkl_orig": inferkl_orig_agg, - "infer_correction/inferkl_final": inferkl_final_agg, - "infer_correction/raw_is_mean": raw_is_mean, - "infer_correction/raw_is_std": raw_is_std, - "infer_correction/raw_is_min": raw_is_min, - "infer_correction/raw_is_max": raw_is_max, - } - return pg_loss, final_mask, stats + @register(dispatch_mode=Dispatch.ONE_TO_ALL) def do_checkpoint(self, global_step): with Timer("do_checkpoint") as total_timer: diff --git a/roll/pipeline/rlvr/actor_pg_worker.py b/roll/pipeline/rlvr/actor_pg_worker.py index b518acc2e..dd651b414 100644 --- a/roll/pipeline/rlvr/actor_pg_worker.py +++ b/roll/pipeline/rlvr/actor_pg_worker.py @@ -4,6 +4,7 @@ from roll.distributed.scheduler.protocol import DataProto from roll.utils.functionals import masked_mean, agg_loss, compute_approx_kl from roll.pipeline.rlvr.actor_worker import ActorWorker +from roll.utils.infer_correction import InferCorrectionHandler class ActorPGWorker(ActorWorker): @@ -30,11 +31,11 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ response_mask = data.batch["response_mask"][:, 1:].long() - infer_log_probs = data.batch.get("infer_logprobs", None) final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] old_log_probs = data.batch["old_log_probs"] + infer_log_probs = data.batch.get("infer_logprobs", None) advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( @@ -77,13 +78,19 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): pg_loss = self._compute_kimi15_loss(ratio, log_probs, old_log_probs, advantages) else: raise ValueError(f"Unsupported pg_variant: {pg_variant}") - + + infer_stats = {} if infer_log_probs is not None and 
self.pipeline_config.infer_correction: - loss, infer_response_mask, infer_stats=self.infer_correction( - old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, - response_mask=response_mask,pg_loss=loss) + correction_handler = InferCorrectionHandler(self.pipeline_config) + pg_loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=pg_loss + ) + # 更新最终掩码 final_response_mask = (final_response_mask.bool() & infer_response_mask).long() - + weighted_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights) original_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, @@ -134,7 +141,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): # 使用缓存的指标 pg_metrics = self._get_pg_metrics(data) - pg_metrics.updata(infer_stats) + + pg_metrics.update(infer_stats) return total_loss, pg_metrics diff --git a/roll/pipeline/rlvr/actor_worker.py b/roll/pipeline/rlvr/actor_worker.py index 603ab4d7b..48e2fcf32 100644 --- a/roll/pipeline/rlvr/actor_worker.py +++ b/roll/pipeline/rlvr/actor_worker.py @@ -4,6 +4,7 @@ from roll.distributed.scheduler.protocol import DataProto from roll.pipeline.base_worker import ActorWorker as BaseActorWorker from roll.utils.functionals import masked_mean, agg_loss, compute_approx_kl +from roll.utils.infer_correction import InferCorrectionHandler class ActorWorker(BaseActorWorker): @@ -20,7 +21,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): ref_log_probs = data.batch["ref_log_probs"] old_log_probs = data.batch["old_log_probs"] infer_log_probs = data.batch.get("infer_logprobs", None) - + advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( @@ -70,13 +71,21 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) loss = torch.where(advantages < 0, dual_clip_loss, loss) - + + infer_stats = {} if infer_log_probs is not None and self.pipeline_config.infer_correction: - loss, infer_response_mask, infer_stats=self.infer_correction( - old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, - response_mask=response_mask,pg_loss=loss) + correction_handler = InferCorrectionHandler(self.pipeline_config) + loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=loss + ) + # 更新最终掩码 final_response_mask = (final_response_mask.bool() & infer_response_mask).long() - + + + weighted_pg_loss = agg_loss(loss_mat=loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights, loss_scale=loss_scale) @@ -124,6 +133,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): total_loss = total_loss + topr_neg_loss * self.pipeline_config.use_topr_neg_loss_coef metrics['actor/topr_neg_loss'] = topr_neg_loss.detach().item() + loss_metric = { "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(), "actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(), diff --git a/roll/utils/infer_correction.py b/roll/utils/infer_correction.py new file mode 100644 index 000000000..215203081 --- /dev/null +++ b/roll/utils/infer_correction.py @@ -0,0 +1,380 @@ +from typing import Literal, Optional, Tuple, 
Dict, Any +import torch + +class StatsCollector: + """统一收集诊断指标的类""" + def __init__(self, prefix: str = "infer_correction"): + self.prefix = prefix + self.stats: Dict[str, Any] = {} + self.tensor_stats: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {} + + def add(self, name: str, value: Any): + """添加标量指标""" + self.stats[f"{self.prefix}/{name}"] = value.item() if torch.is_tensor(value) else value + + def add_tensor_stat(self, name: str, tensor: torch.Tensor, mask: torch.Tensor): + """添加张量统计指标(延迟计算)""" + self.tensor_stats[name] = (tensor, mask) + + def compute_tensor_stats(self): + """严格遵循原始代码的数据移动策略""" + for name, (tensor, mask) in self.tensor_stats.items(): + # 1. 确保在同一设备上 + if tensor.device != mask.device: + mask = mask.to(tensor.device) + + # 2. 直接在原始代码风格中计算:先筛选,再移动到CPU + mask=mask.bool() + valid = tensor[mask] + + # 3. 严格按照原始代码逻辑处理 + if valid.numel() > 0: + # 关键:先detach()再item(),确保在CPU上计算 + valid_cpu = valid.detach().cpu() + self.add(f"{name}_mean", valid_cpu.mean().item()) + self.add(f"{name}_std", valid_cpu.std(unbiased=False).item() if valid_cpu.numel() > 1 else 0.0) + self.add(f"{name}_min", valid_cpu.min().item()) + self.add(f"{name}_max", valid_cpu.max().item()) + else: + self.add(f"{name}_mean", 0.0) + self.add(f"{name}_std", 0.0) + self.add(f"{name}_min", 0.0) + self.add(f"{name}_max", 0.0) + + self.tensor_stats.clear() + + def get_metrics(self) -> Dict[str, float]: + """获取所有指标""" + return self.stats.copy() + +class InferCorrectionHandler: + """处理重要性采样校正和样本拒绝的核心类""" + def __init__(self, pipeline_config: "PPOConfig"): + self.pipeline_config = pipeline_config + self.stats = StatsCollector() + + def __call__( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor, + pg_loss: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]: + """ + 主入口:执行重要性采样校正和样本拒绝 + + Args: + old_log_probs: 历史策略的log概率 [B, T] + infer_log_probs: 生成时策略的log概率 [B, T] + response_mask: 有效token掩码 [B, T] + pg_loss: 原始策略梯度损失 [B, T] + + Returns: + weighted_loss: 重加权后的损失 + final_mask: 最终保留的token掩码 + metrics: 诊断指标字典 + """ + # 1. 对齐形状 + infer_log_probs = self._align_shapes(old_log_probs, infer_log_probs, response_mask) + + # 2. 计算IS权重 + ratio, raw_is_weight, is_weight = self._compute_is_weights(old_log_probs, infer_log_probs, response_mask) + + # 3. 收集基础统计 + self._collect_base_stats(ratio, response_mask) + + # 4. 应用拒绝策略 + keep_mask = response_mask.clone() + keep_mask = self._apply_token_rejection(ratio, keep_mask) + keep_mask = self._apply_catastrophic_rejection(ratio, keep_mask, response_mask) + keep_mask = self._apply_sequence_rejection(ratio, keep_mask, response_mask) + + # 5. 计算拒绝统计 + self._collect_rejection_stats(ratio, raw_is_weight, keep_mask, response_mask) + + # 6. 重加权损失 + weighted_loss = pg_loss * is_weight + + # 7. 计算KL指标 + self._compute_kl_metrics(old_log_probs, infer_log_probs, keep_mask, response_mask) + + # 8. 
批量计算张量统计 + self.stats.compute_tensor_stats() + + return weighted_loss, keep_mask, self.stats.get_metrics() + + def _align_shapes( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor + ) -> torch.Tensor: + """对齐log概率张量形状""" + if infer_log_probs.shape[1] == old_log_probs.shape[1] + 1: + infer_log_probs = infer_log_probs[:, 1:] + + assert old_log_probs.shape == infer_log_probs.shape == response_mask.shape, ( + f"Shape mismatch: old_log_probs {old_log_probs.shape}, " + f"infer_log_probs {infer_log_probs.shape}, " + f"response_mask {response_mask.shape}" + ) + return infer_log_probs + + def _compute_is_weights( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + 计算重要性采样权重 + + Returns: + ratio: 原始重要性比率 [B, T] + raw_is_weight: 未裁剪的IS权重 [B, T] + is_weight: 裁剪后的IS权重 [B, T] + """ + log_ratio = old_log_probs - infer_log_probs + ratio = torch.exp(log_ratio) + + if self.pipeline_config.infer_is_mode == "token": + raw_is_weight = ratio + elif self.pipeline_config.infer_is_mode == "sequence": + # 序列级IS:使用序列总log-ratio + log_ratio_sum = self._masked_sum(log_ratio, response_mask, dim=-1).unsqueeze(-1) + seq_ratio = torch.exp(log_ratio_sum) + raw_is_weight = seq_ratio.expand_as(ratio) + # 收集序列级统计 + self.stats.add_tensor_stat("seq_ratio", seq_ratio.squeeze(-1), torch.ones_like(seq_ratio.squeeze(-1), dtype=torch.bool)) + else: # "None" or any other value + raw_is_weight = torch.ones_like(ratio) + + # 裁剪IS权重 + is_weight = raw_is_weight.clamp( + min=self.pipeline_config.infer_is_threshold_min, + max=self.pipeline_config.infer_is_threshold_max + ).detach() + + return ratio, raw_is_weight, is_weight + + def _collect_base_stats(self, ratio: torch.Tensor, response_mask: torch.Tensor): + """收集基础统计指标""" + self.stats.add_tensor_stat("token_ratio", ratio, response_mask) + + if self.pipeline_config.infer_is_mode in ("token", "sequence"): + # 1. 裁剪比例统计(现有代码) + clipped_low = ratio <= self.pipeline_config.infer_is_threshold_min + clipped_high = ratio >= self.pipeline_config.infer_is_threshold_max + clipped = clipped_low | clipped_high + self.stats.add("token_clip_low_frac", self._agg_loss(clipped_low.float(), response_mask)) + self.stats.add("token_clip_high_frac", self._agg_loss(clipped_high.float(), response_mask)) + self.stats.add("token_clip_frac", self._agg_loss(clipped.float(), response_mask)) + + # 2. 
添加缺失的:裁剪后权重的分布统计 + if self.pipeline_config.infer_is_mode == "token": + # 重新计算裁剪后的权重 + is_weight = ratio.clamp( + min=self.pipeline_config.infer_is_threshold_min, + max=self.pipeline_config.infer_is_threshold_max + ) + # 添加缺失的统计 + self.stats.add_tensor_stat("token_is_weight", is_weight, response_mask) + + elif self.pipeline_config.infer_is_mode == "sequence": + # 序列级IS权重已在_compute_is_weights中添加 + pass + + def _apply_token_rejection( + self, + ratio: torch.Tensor, + keep_mask: torch.Tensor + ) -> torch.Tensor: + """应用token级拒绝策略""" + if not self.pipeline_config.enable_token_reject: + return keep_mask + + ratio_too_high = ratio > self.pipeline_config.infer_token_mask_threshold_max + ratio_too_low = ratio < self.pipeline_config.infer_token_mask_threshold_min + token_reject = ratio_too_high | ratio_too_low + + # 更新掩码:丢弃被拒绝的token + new_keep_mask = keep_mask & (~token_reject) + + # 收集统计 + self.stats.add("token_reject_low_frac", self._agg_loss(ratio_too_low.float(), keep_mask)) + self.stats.add("token_reject_high_frac", self._agg_loss(ratio_too_high.float(), keep_mask)) + + return new_keep_mask + + def _apply_catastrophic_rejection( + self, + ratio: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ) -> torch.Tensor: + """应用灾难性拒绝策略""" + if not self.pipeline_config.enable_catastrophic_reject: + return keep_mask + + # 识别灾难性token + catastrophic = (ratio < self.pipeline_config.infer_catastrophic_threshold) & response_mask + + # 检查哪些序列包含灾难性token + seq_has_catastrophic = catastrophic.any(dim=-1, keepdim=True) + + # 更新掩码:丢弃包含灾难性token的整个序列 + new_keep_mask = keep_mask & (~seq_has_catastrophic) + + # 收集统计 + catastrophic_token_frac = self._agg_loss(catastrophic.float(), response_mask) + self.stats.add("catastrophic_token_frac", catastrophic_token_frac) + + # 计算包含灾难性token的序列比例 + seq_has_valid = response_mask.any(dim=-1) + seq_has_catastrophic_flat = catastrophic.any(dim=-1) & seq_has_valid + catastrophic_seq_frac = ( + seq_has_catastrophic_flat.sum().float() / seq_has_valid.sum().float() + if seq_has_valid.sum() > 0 else 0.0 + ) + self.stats.add("catastrophic_seq_frac", catastrophic_seq_frac) + + return new_keep_mask + + def _apply_sequence_rejection( + self, + ratio: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ) -> torch.Tensor: + """应用序列级拒绝策略""" + if self.pipeline_config.enable_seq_reject in (None, "None", "none"): + return keep_mask + + # 计算序列级比率 + if self.pipeline_config.enable_seq_reject == "sequence": + log_ratio_agg = self._masked_sum(torch.log(ratio), response_mask, dim=-1) + elif self.pipeline_config.enable_seq_reject == "geometric": + log_ratio_agg = self._masked_mean(torch.log(ratio), response_mask, dim=-1) + else: + return keep_mask + + seq_ratio = torch.exp(log_ratio_agg) + + # 识别要拒绝的序列 + seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max + seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min + seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) + + # 更新掩码 + new_keep_mask = keep_mask & (~seq_reject) + + # 收集统计 + seq_has_valid = response_mask.any(dim=-1) + total_valid_seqs = seq_has_valid.sum().item() + + seq_reject_low = seq_too_low & seq_has_valid + seq_reject_high = seq_too_high & seq_has_valid + + seq_reject_low_frac = seq_reject_low.sum().item() / total_valid_seqs if total_valid_seqs > 0 else 0.0 + seq_reject_high_frac = seq_reject_high.sum().item() / total_valid_seqs if total_valid_seqs > 0 else 0.0 + + self.stats.add("seq_reject_low_frac", seq_reject_low_frac) + 
self.stats.add("seq_reject_high_frac", seq_reject_high_frac) + + return new_keep_mask + + def _collect_rejection_stats( + self, + ratio: torch.Tensor, + raw_is_weight: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ): + """收集拒绝相关的统计指标""" + + # 计算被拒绝的token + rejected_mask = response_mask & (~keep_mask) + self.stats.add("reject_frac", self._agg_loss(rejected_mask.float(), response_mask)) + + + # 仅在序列拒绝启用时计算序列级拒绝率 + if self.pipeline_config.enable_seq_reject not in (None, "None", "none"): + seq_has_valid = response_mask.any(dim=-1) + seq_completely_rejected = (~keep_mask).all(dim=-1) & seq_has_valid + total_valid_seqs = seq_has_valid.sum().item() + rejected_seqs = seq_completely_rejected.sum().item() + seq_reject_frac = rejected_seqs / total_valid_seqs if total_valid_seqs > 0 else 0.0 + self.stats.add("seq_reject_frac", seq_reject_frac) + else: + # 未启用时显式设为0.0 + self.stats.add("seq_reject_frac", 0.0) + + + if self.pipeline_config.infer_is_mode in ("token", "sequence"): + # 使用已计算的rejected_mask + clipped_mask = ((raw_is_weight <= self.pipeline_config.infer_is_threshold_min) | + (raw_is_weight >= self.pipeline_config.infer_is_threshold_max)) & response_mask + + clip_and_reject_frac = self._agg_loss((clipped_mask & rejected_mask).float(), response_mask) + clip_or_reject_frac = self._agg_loss((clipped_mask | rejected_mask).float(), response_mask) + + self.stats.add("token_clip_and_reject_frac", clip_and_reject_frac) + self.stats.add("token_clip_or_reject_frac", clip_or_reject_frac) + else: + # 关键:为未启用IS的情况提供默认值 + self.stats.add("token_clip_and_reject_frac", 0.0) + self.stats.add("token_clip_or_reject_frac", 0.0) + + def _compute_kl_metrics( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ): + """计算KL散度指标""" + # 原始KL(所有有效token) + inferkl_orig = self._compute_approx_kl(infer_log_probs, old_log_probs, response_mask, kl_penalty="kl") + inferkl_orig_agg = self._agg_loss(inferkl_orig, response_mask) + self.stats.add("inferkl", inferkl_orig_agg) + + # 拒绝后KL(仅保留的token) + inferkl_final = self._compute_approx_kl(infer_log_probs, old_log_probs, keep_mask, kl_penalty="kl") + inferkl_final_agg = self._agg_loss(inferkl_final, keep_mask) + self.stats.add("inferkl_reject", inferkl_final_agg) + + # --- 辅助方法(使用已有工具函数)--- + def _compute_approx_kl( + self, + log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + action_mask: torch.Tensor, + kl_penalty: str = "kl" + ) -> torch.Tensor: + """使用已有的compute_approx_kl函数计算近似KL散度""" + from roll.utils.functionals import compute_approx_kl + return compute_approx_kl( + log_probs=log_probs, + log_probs_base=log_probs_base, + action_mask=action_mask, + kl_penalty=kl_penalty + ) + + def _agg_loss(self, loss_mat: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor: + """使用已有的agg_loss函数聚合损失""" + from roll.utils.functionals import agg_loss + return agg_loss( + loss_mat=loss_mat, + loss_mask=loss_mask, + loss_agg_mode=self.pipeline_config.loss_agg_mode + ) + + def _masked_sum(self, tensor: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: + """使用已有的masked_sum函数在掩码区域求和""" + from roll.utils.functionals import masked_sum + return masked_sum(tensor, mask, dim=dim) + + def _masked_mean(self, tensor: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: + """使用已有的masked_mean函数在掩码区域计算均值""" + from roll.utils.functionals import masked_mean + return masked_mean(tensor, mask, dim=dim) \ No newline at end of file From 
a595ec3d5f09091b3b8c6f2196d056d3e2264f77 Mon Sep 17 00:00:00 2001 From: millioniron Date: Wed, 3 Dec 2025 14:08:18 +0800 Subject: [PATCH 3/7] new file: docs_roll/docs/User Guides/Configuration/infer_correction.md new file: examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml new file: examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml new file: examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh new file: examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh modified: roll/configs/base_config.py modified: roll/configs/generating_args.py modified: roll/distributed/scheduler/generate_scheduler.py modified: roll/pipeline/agentic/env_manager/step_env_manager.py modified: roll/pipeline/agentic/env_manager/traj_env_manager.py modified: roll/pipeline/base_worker.py modified: roll/pipeline/rlvr/actor_pg_worker.py modified: roll/pipeline/rlvr/actor_worker.py --- .../Configuration/infer_correction.md | 142 ++++++++++ .../agentic_webshop_infer_correction.yaml | 183 ++++++++++++ .../rlvr_infer_correction_config.yaml | 265 ++++++++++++++++++ .../run_agentic_pipeline_webshop.sh | 10 + .../run_rlvr_pipeline.sh | 5 + roll/configs/base_config.py | 56 ++++ roll/configs/generating_args.py | 4 + .../agentic/env_manager/step_env_manager.py | 1 + .../agentic/env_manager/traj_env_manager.py | 3 +- roll/pipeline/base_worker.py | 165 ++++++++++- roll/pipeline/rlvr/actor_pg_worker.py | 10 +- roll/pipeline/rlvr/actor_worker.py | 58 +--- 12 files changed, 846 insertions(+), 56 deletions(-) create mode 100644 docs_roll/docs/User Guides/Configuration/infer_correction.md create mode 100644 examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml create mode 100644 examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml create mode 100644 examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh create mode 100644 examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh diff --git a/docs_roll/docs/User Guides/Configuration/infer_correction.md b/docs_roll/docs/User Guides/Configuration/infer_correction.md new file mode 100644 index 000000000..ecdce6325 --- /dev/null +++ b/docs_roll/docs/User Guides/Configuration/infer_correction.md @@ -0,0 +1,142 @@ +# 训推差异修复 + + +## 简介 + +训推差异是由于RL训练过程中,训练器和生成器之间由于后端不同(vLLM vs SGLang vs FSDP vs Megatron),精度不同(FP8 vs FP16 vs BF16 vs FP32),形成了一种类似off-policy gap,会导致训练不稳定和策略崩溃。 + + +## 实现原理 + + +修复训推差异大致可分为两种方法(1)对训练器和生成器进行策略修正(2)使用infer_log_probs直接代替old_log_probs(trainer)进行PPO ratio计算。第二种方案比较直接,我们着重说明第一种方法。 + +### 对训练器和生成器进行策略修正 + +### IS权重矫正 +通过对训练器(old_log_probs)和生成器(infer_log_prob)之间进行重要性采样矫正,弥合训推差异。与off-policy算法类似,IS权重矫正可以区分token级别和sequence级别,只能选择一个。 + +### MASK过滤掉不符合条件的样本 +与IS权重修正不同的是,此方法对于超过阈值的样本直接进行mask遮掩,过滤掉不符合的样本。涉及的方法有(1)token级别:过滤掉不符合条件的token(2)灾难性token:过滤掉出现灾难性严重偏差的token的句子样本(3)sequence级别:对sequence进行IS计算,过滤掉不符合的句子样本(4)sequence级别,使用几何平均来计算IS权重,但指标也更为敏感 + + +## 关键参数配置 + +生成器是否返回infer_log_probs + +GeneratingArguments: + +```yaml +actor_infer: + model_args: + disable_gradient_checkpointing: true + dtype: fp16 + generating_args: + max_new_tokens: ${response_length} + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: ${num_return_sequences_in_group} + logprobs: 1 +``` +当logprobs:大于0时,返回infer_log_probs + + +------- + +参数的配置在PPOConfig和中,关键配置参数如下: + +```yaml +infer_correction: true + +infer_is_mode: token # 可选 token sequence +infer_is_threshold_min: 0.0 +infer_is_threshold_max: 2.0 # 1.5~5.0 + +enable_token_reject: true +infer_token_mask_threshold_min: 0.0 
+infer_token_mask_threshold_max: 2.0 # 2~10 + +enable_catastrophic_reject: true +infer_catastrophic_threshold: 1e-4 + +enable_seq_reject: sequence 可选None sequence geometric +infer_seq_mask_threshold_min: 0.1 +infer_seq_mask_threshold_max: 10 +``` + + +### infer_correction +- **含义**:控制是否启用训推差异修复机制。若启用,系统将使用 `infer_log_probs` 对策略梯度进行矫正。 +- **默认值**:`false` + +### infer_is_mode +- **含义**:指定重要性采样(IS)权重的计算粒度。 +- **可选值**: + - `"token"`:每个 token 独立计算 IS 权重 + - `"sequence"`:基于整个 response 序列的 log-ratio 求和后广播至所有 token + - `"none"`(或未设置):不应用 IS 加权,权重恒为 1 +- **默认值**:若未设置,默认为 `"token"` +- **注意**:不可同时使用多种模式,仅能选择其一。 + +### infer_is_threshold_min +- **含义**:IS 权重的下限阈值,用于裁剪过小的权重以控制方差。 +- **默认值**:`0.0` +- **建议**:通常保留为 `0.0`,以保持无偏性下界 + +### infer_is_threshold_max +- **含义**:IS 权重的上限阈值,防止极端大的权重主导梯度。 +- **默认值**:`2.0` +- **建议**:`"token"`级别推荐为 `1.5 ~ 5.0` `"sequence"`级别推荐为2.0 - 10.0 + +### enable_token_reject +- **含义**:是否启用 token 级别的样本拒绝机制。 +- **默认值**:`false` +- **作用**:结合 `infer_token_mask_threshold_min/max`,mask 掉 IS ratio 超出合法区间的 token。 + +### infer_token_mask_threshold_min +- **含义**:token 级 IS ratio(`old_log_probs / infer_log_probs` 的指数)的下限。 +- **默认值**:`0.0` +- **典型值**:`0.0`通常可设为1/max + +### infer_token_mask_threshold_max +- **含义**:遮掩token 级 IS ratio 的上限。 +- **默认值**:`2.0` +- **典型范围**:`1.5 ~ 5.0` + +### enable_catastrophic_reject +- **含义**:是否启用“灾难性偏差”检测并拒绝整句样本。 +- **默认值**:`false` +- **触发条件**:只要序列中存在任一 token 满足 `ratio < infer_catastrophic_threshold`,则整句被 mask。 + +### infer_catastrophic_threshold +- **含义**:灾难性拒绝的 ratio 下限阈值。 +- **默认值**:`1e-4` +- **解释**:当 `infer_log_probs` 远大于 `old_log_probs`(即生成器过于“自信”),导致 `ratio = exp(old - infer)` 极小 + +### enable_seq_reject +- **含义**:是否启用序列级别的拒绝机制,以及使用何种聚合方式。 +- **可选值**: + - `null` / `false`:禁用 + - `"sequence"`:使用 log-ratio **求和** 计算序列 IS ratio + - `"geometric"`:使用 log-ratio **平均**(等价于几何平均概率)计算序列 IS ratio +- **默认值**:`null` + +### infer_seq_mask_threshold_min +- **含义**:遮掩序列级 IS ratio 的下限。 +- **默认值**:`0.1` +- **典型值**:通常可设为1/max,当使用`"geometric"`时,最好强制设为1/max + + +### infer_seq_mask_threshold_max +- **含义**:遮掩序列级 IS ratio 的上限。 +- **默认值**:`10.0` +- **典型范围**:当使用`"sequence"`时,推荐`2.0 ~ 10.0`,但随着长度增加可适当放宽。当使用`"geometric"`时,推荐设置为1.0001 - 1.001 + + + +## 使用建议 + +1. 通常情况下,old_log_prob << infer_log_porb, 阈值的下限就比较重要了。并不建议使用sequence级别的IS或MASK + diff --git a/examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml b/examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml new file mode 100644 index 000000000..78fefee47 --- /dev/null +++ b/examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml @@ -0,0 +1,183 @@ +defaults: + - ../config/traj_envs@_here_ + - ../config/deepspeed_zero@_here_ + - ../config/deepspeed_zero2@_here_ + - ../config/deepspeed_zero3@_here_ + - ../config/deepspeed_zero3_cpuoffload@_here_ + +hydra: + run: + dir: . 
+ output_subdir: null + +exp_name: "agentic_pipeline_webshop_infer_correction" +seed: 42 +logging_dir: ./output/logs +output_dir: ./output +render_save_dir: ./output/render +system_envs: + USE_MODELSCOPE: '1' + +#track_with: wandb +#tracker_kwargs: +# api_key: +# project: roll-agentic +# name: ${exp_name}_webshop +# notes: "agentic_pipeline" +# tags: +# - agentic +# - roll +# - baseline + +#track_with: swanlab +#tracker_kwargs: +# login_kwargs: +# api_key: your_api_key +# project: roll-agentic +# logdir: debug +# experiment_name: ${exp_name} +# tags: +# - roll +# - agentic +# - debug + +track_with: tensorboard +tracker_kwargs: + log_dir: /data/oss_bucket_0/yali/llm/tensorboard/roll_exp/agentic_webshop + +num_gpus_per_node: 8 + +max_steps: 1024 +save_steps: 10000 +logging_steps: 1 +eval_steps: 10 +resume_from_checkpoint: false + +rollout_batch_size: 64 +val_batch_size: 64 +sequence_length: 8192 + +reward_clip: 20 +advantage_clip: 0.2 # 0.1-0.3 +ppo_epochs: 1 +adv_estimator: "grpo" +#pg_clip: 0.1 +max_grad_norm: 1.0 +#dual_clip_loss: True +init_kl_coef: 0.0 +whiten_advantages: true +entropy_loss_coef: 0 + +pretrain: Qwen/Qwen2.5-7B-Instruct +reward_pretrain: Qwen/Qwen2.5-7B-Instruct + +# infer correction + +infer_correction: true + +infer_is_mode: token #token sequence +infer_is_threshold_min: 0.0 +infer_is_threshold_max: 2.0 # 1.5~5.0 + +enable_token_reject: false +infer_token_mask_threshold_min: 0.0 +infer_token_mask_threshold_max: 2.0 # 2~10 + +enable_catastrophic_reject: false +infer_catastrophic_threshold: 1e-4 + + +enable_seq_reject: None # None sequence geometric +infer_seq_mask_threshold_min: 0.999 +infer_seq_mask_threshold_max: 1.001 + +# enable_seq_reject: geometric +# infer_seq_mask_threshold_min: 0.999 +# infer_seq_mask_threshold_max: 1.001 + + +actor_train: + model_args: + attn_implementation: fa2 + disable_gradient_checkpointing: false + dtype: bf16 + model_type: ~ + training_args: + learning_rate: 1.0e-6 + weight_decay: 0 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 8 + warmup_steps: 10 + data_args: + template: qwen2_5 + strategy_args: + strategy_name: megatron_train + strategy_config: + tensor_model_parallel_size: 1 + context_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + use_distributed_optimizer: true + recompute_granularity: full + max_grad_norm: ${max_grad_norm} + device_mapping: list(range(0,8)) + infer_batch_size: 1 + +actor_infer: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + generating_args: + max_new_tokens: 1024 # single-turn response length + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: 1 + logprobs: 1 + data_args: + template: qwen2_5 + strategy_args: + strategy_name: sglang + strategy_config: + mem_fraction_static: 0.85 + load_format: auto + device_mapping: list(range(0,8)) + infer_batch_size: 1 + +reference: + model_args: + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: ~ + data_args: + template: qwen2_5 + strategy_args: + strategy_name: hf_infer + strategy_config: ~ + device_mapping: list(range(0,8)) + infer_batch_size: 1 + +reward_normalization: + grouping: traj_group_id # 可以tags(env_type)/traj_group_id(group)/batch(rollout_batch)... 
group_by计算reward/adv + method: mean_std # asym_clip / identity / mean_std + +train_env_manager: + format_penalty: -0.05 + num_env_groups: 8 + group_size: 8 + max_env_num_per_worker: 1 # The max_env_num_per_worker must be set to 1 to avoid conflicts with the webshop simple server. + tags: [WebShopEnv] + num_groups_partition: [8] # If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation + +val_env_manager: + num_env_groups: 64 + group_size: 1 # should be set to 1 because val temperature is set to 0 and same prompt leads to same output + max_env_num_per_worker: 1 # The max_env_num_per_worker must be set to 1 to avoid conflicts with the webshop simple server. + tags: [WebShopEnv] + num_groups_partition: [64] # TODO: If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation + +custom_envs: + WebShopEnv: + ${custom_env.WebShopEnv} \ No newline at end of file diff --git a/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml new file mode 100644 index 000000000..e2cbac8a7 --- /dev/null +++ b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml @@ -0,0 +1,265 @@ +hydra: + run: + dir: . + output_subdir: null + +exp_name: "qwen2.5-7B-rlvr-infer-correction-config" +seed: 42 +logging_dir: ./output/logs +output_dir: ./output +system_envs: + USE_MODELSCOPE: '1' + + +checkpoint_config: + type: file_system + output_dir: /data/cpfs_0/rl_examples/models/${exp_name} + +track_with: tensorboard +tracker_kwargs: + log_dir: ./rl_examples/llm/tensorboard/roll_exp/rlvr + +num_gpus_per_node: 8 + +max_steps: 500 +save_steps: 100 +logging_steps: 1 +eval_steps: 10 +resume_from_checkpoint: false + + +rollout_batch_size: 64 # prompt +prompt_length: 2048 +response_length: 4096 + +num_return_sequences_in_group: 8 +ppo_epochs: 1 +adv_estimator: "reinforce" + +# clip +value_clip: 0.5 +reward_clip: 10 +advantage_clip: 2.0 +dual_clip_loss: true + +# normalize +norm_mean_type: ~ +norm_std_type: ~ + +# data mask +max_len_mask: true +difficulty_mask: true +difficulty_low_threshold: 0.1 +difficulty_high_threshold: 0.95 +error_max_len_clip: false + +# data weight +difficulty_loss_weight: false +length_loss_weight: false + +# reward +add_token_level_kl: false + +# advantage +whiten_advantages: true + +# dynamic sampling scheduler +# use_additional_prompts: true +# max_running_requests: 256 +# is_num_return_sequences_expand: false + +pretrain: Qwen/Qwen2.5-7B-Instruct +reward_pretrain: Qwen/Qwen2.5-7B-Instruct + + +# infer correction +infer_correction: true + +infer_is_mode: token +infer_is_threshold_min: 0.0 +infer_is_threshold_max: 2.0 # 1.5~5.0 + +enable_token_reject: false +infer_token_rs_threshold_min: 0.0 +infer_token_rs_threshold_max: 2.0 # 2~10 + +enable_catastrophic_reject: false +infer_catastrophic_threshold: 1e-4 + +enable_seq_reject: None + +# enable_seq_reject: sequence +# infer_seq_rs_threshold_min: 0.1 +# infer_seq_rs_threshold_max: 10 + +# enable_seq_reject: geometric +# infer_seq_rs_threshold_min: 0.999 +# infer_seq_rs_threshold_max: 1.001 + +validation: + data_args: + template: qwen2.5 + file_name: + - data/math_benchmarks.jsonl + generating_args: + max_new_tokens: ${response_length} + top_p: 0.6 + top_k: 50 + num_beams: 1 + temperature: 0.6 + num_return_sequences: 1 + + +actor_train: + model_args: + disable_gradient_checkpointing: false + dtype: bf16 + model_type: 
~ + training_args: + learning_rate: 1.0e-6 + weight_decay: 0 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 32 + warmup_steps: 20 + num_train_epochs: 50 + data_args: + template: qwen2.5 + file_name: + - data/code_KodCode_data.jsonl + - data/llm_judge_Multi-subject-RLVR_deal_new.jsonl + - data/math_deepmath_deal.jsonl + - data/general_ifeval_train_deal.jsonl + - data/general_CrossThink-QA_deal.jsonl + domain_interleave_probs: + math_rule: 0.4 + code_sandbox: 0.3 + llm_judge: 0.1 + crossthinkqa: 0.1 + ifeval: 0.1 + dataset_dir: data + messages: messages + interleave_probs: "1.0" + preprocessing_num_workers: 16 + strategy_args: + strategy_name: megatron_train + strategy_config: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + use_distributed_optimizer: true + recompute_granularity: full + device_mapping: list(range(0,16)) + infer_batch_size: 4 + +actor_infer: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + generating_args: + max_new_tokens: ${response_length} + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: ${num_return_sequences_in_group} + logprobs: 1 + data_args: + template: qwen2.5 + strategy_args: + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.8 + block_size: 16 + max_model_len: 8000 + device_mapping: list(range(0,12)) + infer_batch_size: 1 + +reference: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + model_type: ~ + data_args: + template: qwen2.5 + strategy_args: + strategy_name: megatron_infer + strategy_config: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + device_mapping: list(range(0,16)) + infer_batch_size: 4 + +rewards: + crossthinkqa: + worker_cls: roll.pipeline.rlvr.rewards.crossthinkqa_rule_reward_worker.CrossThinkQARuleRewardWorker + reward_type: soft + response_length_penalty_coef: 0.0 + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + tag_included: [crossthinkqa] + world_size: 8 + infer_batch_size: 4 + ifeval: + worker_cls: roll.pipeline.rlvr.rewards.ifeval_rule_reward_worker.GeneralRuleRewardWorker + reward_type: soft + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + tag_included: [ifeval] + world_size: 8 + infer_batch_size: 4 + math_rule: + worker_cls: roll.pipeline.rlvr.rewards.math_rule_reward_worker.MathRuleRewardWorker + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + tag_included: [deepmath_103k, aime] + world_size: 8 + infer_batch_size: 1 + code_sandbox: + use_local: true + worker_cls: roll.pipeline.rlvr.rewards.code_sandbox_reward_worker.CodeSandboxRewardWorker + tag_included: [KodCode] + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + world_size: 8 + infer_batch_size: 1 + llm_judge: + # NOTE: llm as judge 也需要gpu, 不能和actor infer共享gpu + worker_cls: roll.pipeline.rlvr.rewards.llm_judge_reward_worker.LLMJudgeRewardWorker + judge_prompt: Qwen2.5-7B-Instruct-RLVR-prompt + judge_model_type: inference + tag_included: [RLVR] + model_args: + model_name_or_path: virtuoussy/Qwen2.5-7B-Instruct-RLVR + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: trl + generating_args: + max_new_tokens: 100 + top_p: 0.8 + top_k: 50 + num_beams: 1 + temperature: 0.8 + num_return_sequences: 1 + data_args: + template: qwen2.5 + strategy_args: + # strategy_name: hf_infer + # 
strategy_config: null + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.8 + block_size: 16 + max_model_len: 8000 + load_format: auto + device_mapping: list(range(12,16)) + infer_batch_size: 4 \ No newline at end of file diff --git a/examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh b/examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh new file mode 100644 index 000000000..ff1a59402 --- /dev/null +++ b/examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh @@ -0,0 +1,10 @@ +#!/bin/bash +# Run `git submodule update --init --recursive` to init submodules before run this script. +set +x + + +pip install -r third_party/webshop-minimal/requirements.txt --trusted-host mirrors.aliyun.com --index-url https://mirrors.aliyun.com/pypi/simple/ +python -m spacy download en_core_web_sm + +CONFIG_PATH=$(basename $(dirname $0)) +python examples/start_agentic_pipeline.py --config_path $CONFIG_PATH --config_name agentic_webshop_infer_correction diff --git a/examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh b/examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh new file mode 100644 index 000000000..6874f4f5e --- /dev/null +++ b/examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh @@ -0,0 +1,5 @@ +#!/bin/bash +set +x + +CONFIG_PATH=$(basename $(dirname $0)) +python examples/start_rlvr_pipeline.py --config_path $CONFIG_PATH --config_name rlvr_infer_correction_config \ No newline at end of file diff --git a/roll/configs/base_config.py b/roll/configs/base_config.py index aa1da4c27..8dc75d3f3 100644 --- a/roll/configs/base_config.py +++ b/roll/configs/base_config.py @@ -395,12 +395,68 @@ class PPOConfig(BaseConfig): field(default="seq-mean-token-mean", metadata={"help": "Loss aggregation mode"}) ) dual_clip_loss: bool = field(default=False, metadata={"help": "Use dual clip loss"}) + enable_reference: bool = field( default=False, metadata={"help": "Whether to enable reference cluster for computing ref_log_probs."} ) enable_old_logprobs_recompute: bool = field(default=False, metadata={"help": "Enable old_logprobs computation optimization for disable caching"}) force_disable_old_logprobs_recompute: bool = field(default=False, metadata={"help": "Force disable old_logprobs computation optimization for disable caching, priority is higher than enable_old_logprobs_recompute"}) + # trainer&rollout mismatch + infer_correction: bool = field( + default=False, + metadata={"help": "Whether to apply importance sampling correction during inference."} + ) + infer_is_mode: Literal["token", "sequence", "none"] = field( + default="token", + metadata={"help": "IS weighting mode: 'token' (per-token ratio), 'sequence' (per-sequence ratio), 'none' (no IS weighting)."} + ) + # Clipping thresholds (used in IS weighting) + infer_is_threshold_min: float = field( + default=0.0, + metadata={"help": "Minimum threshold for IS weight clipping. 
Recommended 0.0 for unbiased estimation."} + ) + infer_is_threshold_max: float = field( + default=2.0, + metadata={"help": "Maximum threshold for IS weight clipping."} + ) + # Token-level rejection + enable_token_reject: bool = field( + default=False, + metadata={"help": "Enable token-level rejection based on IS ratio thresholds."} + ) + infer_token_mask_threshold_min: float = field( + default=0.0, + metadata={"help": "Minimum IS ratio threshold for token rejection."} + ) + infer_token_mask_threshold_max: float = field( + default=2.0, + metadata={"help": "Maximum IS ratio threshold for token rejection."} + ) + # Catastrophic rejection (reject entire sequence if any token ratio is too small) + enable_catastrophic_reject: bool = field( + default=False, + metadata={"help": "Enable catastrophic rejection: reject entire sequence if any valid token has IS ratio below threshold."} + ) + infer_catastrophic_threshold: float = field( + default=1e-4, + metadata={"help": "Threshold below which a token triggers catastrophic rejection of its sequence."} + ) + # Sequence-level rejection + enable_seq_reject: Optional[Literal["sequence", "geometric",'None']] = field( + default=None, + metadata={"help": "Enable sequence-level rejection: 'sequence' uses sum of log-ratios, 'geometric' uses mean. None disables."} + ) + infer_seq_mask_threshold_min: float = field( + default=0.1, + metadata={"help": "Minimum IS ratio threshold for sequence rejection."} + ) + infer_seq_mask_threshold_max: float = field( + default=10.0, + metadata={"help": "Maximum IS ratio threshold for sequence rejection (typically larger than token-level)."} + ) + + def __post_init__(self): super().__post_init__() diff --git a/roll/configs/generating_args.py b/roll/configs/generating_args.py index 68cf88d17..0822614ba 100644 --- a/roll/configs/generating_args.py +++ b/roll/configs/generating_args.py @@ -58,6 +58,10 @@ class GeneratingArguments: default=None, metadata={"help": "Whether to include the stop strings in output text."}, ) + logprobs: Optional[int] = field( + default=None, + metadata={"help": "Whether return infer log-prob."}, + ) def to_dict(self) -> Dict[str, Any]: args = asdict(self) diff --git a/roll/pipeline/agentic/env_manager/step_env_manager.py b/roll/pipeline/agentic/env_manager/step_env_manager.py index 4348605a3..26e95106d 100644 --- a/roll/pipeline/agentic/env_manager/step_env_manager.py +++ b/roll/pipeline/agentic/env_manager/step_env_manager.py @@ -105,6 +105,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): "response_mask": response_mask, "prompt_mask": prompt_mask, "scores": score_tensor, + "infer_logprobs": infer_logprobs, }, batch_size=input_ids.shape[0]), non_tensor_batch={ diff --git a/roll/pipeline/agentic/env_manager/traj_env_manager.py b/roll/pipeline/agentic/env_manager/traj_env_manager.py index 88ab15d91..80f15113f 100644 --- a/roll/pipeline/agentic/env_manager/traj_env_manager.py +++ b/roll/pipeline/agentic/env_manager/traj_env_manager.py @@ -11,7 +11,7 @@ from omegaconf import DictConfig from tensordict import TensorDict from transformers import PreTrainedTokenizer - +import json from roll.pipeline.agentic.llm_proxy import create_llm_proxy, BaseLLMProxy from roll.pipeline.agentic.env_manager.base_env_manager import RolloutCache, BaseEnvManager from roll.utils.env_action_limiter import get_global_limiter @@ -308,6 +308,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): infer_logprobs.extend([0] * len(items["prompt_ids"]) + items["infer_logprobs"]) input_ids =torch.tensor(token_ids, 
dtype=torch.long).unsqueeze(0) + infer_logprobs = torch.tensor(infer_logprobs, dtype=torch.float).unsqueeze(0) attention_mask = torch.tensor([1] * len(token_ids), dtype=torch.long).unsqueeze(0) response_mask = torch.tensor(response_masks, dtype=torch.bool).unsqueeze(0) diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py index d5c84c120..7148752a1 100644 --- a/roll/pipeline/base_worker.py +++ b/roll/pipeline/base_worker.py @@ -2,7 +2,7 @@ import threading import time from typing import Union, Optional, Dict - +import numpy import ray import torch from codetiming import Timer @@ -25,6 +25,7 @@ postprocess_generate, GenerateRequestType, agg_loss, + masked_sum ) from roll.utils.offload_nccl import reload_process_groups from roll.utils.offload_states import OffloadStateType @@ -307,8 +308,10 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ response_mask = data.batch["response_mask"][:, 1:].long() + final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] advantages = data.batch["advantages"] + infer_log_probs = data.batch.get("infer_logprobs", None) log_probs = self.strategy.op_compute_log_probs( logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"] @@ -325,12 +328,18 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-pg_loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) pg_loss = torch.where(advantages < 0, dual_clip_loss, pg_loss) + + if infer_log_probs is not None and self.pipeline_config.infer_correction: + pg_loss, infer_response_mask, infer_stats=self.infer_correction( + old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, + response_mask=response_mask,pg_loss=pg_loss) + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() + + pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=final_response_mask, - pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) - kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=response_mask, - kl_penalty="k3") - kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) approxkl = compute_approx_kl( log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="mse" @@ -372,9 +381,153 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): "actor/policykl": agg_loss(loss_mat=policykl, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(), } + pg_metrics.update(infer_stats) return total_loss, pg_metrics - + + + def infer_correction( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor, + pg_loss: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, dict]: + """ + 处理 importance sampling ratio,支持 IS 裁剪与多种 reject 策略。 + 返回更新后的 pg_loss、mask 和详细统计信息。 + """ + # Step 0: Shape alignment + if infer_log_probs.shape[1] == old_log_probs.shape[1]+1: + infer_log_probs = infer_log_probs[:, 1:] # align with response_mask[:, 1:] + assert 
old_log_probs.shape == infer_log_probs.shape == response_mask.shape, \ + f"Shape mismatch: {old_log_probs.shape}, {infer_log_probs.shape}, {response_mask.shape}" + # Step 1: Compute log-ratio and ratio + log_ratio = old_log_probs - infer_log_probs # [B, T] + ratio = torch.exp(log_ratio) # [B, T] + # Step 2: Apply IS weighting strategy (optional) + if self.pipeline_config.infer_is_mode == "token": + raw_is_weight = ratio + elif self.pipeline_config.infer_is_mode == "sequence": + log_ratio_sum = masked_sum(log_ratio, response_mask, dim=-1).unsqueeze(-1) # [B, 1] + raw_is_weight = torch.exp(log_ratio_sum).expand_as(ratio) # [B, T] + elif self.pipeline_config.infer_is_mode in (None, "none", ""): + raw_is_weight = torch.ones_like(ratio) + else: + raw_is_weight = torch.ones_like(ratio) + # Clamp to get final is_weight (used for loss) + is_weight = raw_is_weight.clamp( + min=self.pipeline_config.infer_is_threshold_min, + max=self.pipeline_config.infer_is_threshold_max + ).detach() + # Step 3: Build rejection mask + original_valid = response_mask > 0.5 # [B, T], bool + keep_mask = original_valid.clone() + # (a) Token-level ratio reject + if getattr(self.pipeline_config, 'enable_token_reject', False): + ratio_too_high = ratio > self.pipeline_config.infer_token_mask_threshold_max + ratio_too_low = ratio < self.pipeline_config.infer_token_mask_threshold_min + token_reject = ratio_too_high | ratio_too_low + keep_mask = keep_mask & (~token_reject) + # (b) Catastrophic reject + if getattr(self.pipeline_config, 'enable_catastrophic_reject', False): + catastrophic = (ratio < self.pipeline_config.infer_catastrophic_threshold) & original_valid + has_catastrophic = catastrophic.any(dim=-1, keepdim=True) + keep_mask = keep_mask & (~has_catastrophic) + # (c) Sequence-level reject + if getattr(self.pipeline_config, 'enable_seq_reject', False): + if self.pipeline_config.enable_seq_reject=="sequence": + log_ratio_sum = masked_sum(log_ratio, response_mask, dim=-1) # [B] + seq_ratio = torch.exp(log_ratio_sum) # [B] + seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max + seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min + seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) + keep_mask = keep_mask & (~seq_reject) + elif self.pipeline_config.enable_seq_reject=="geometric": + log_ratio_mean = masked_mean(log_ratio, response_mask, dim=-1) # [B] + seq_ratio = torch.exp(log_ratio_mean) # [B] + seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max + seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min + seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) + keep_mask = keep_mask & (~seq_reject) + # final_mask = keep_mask.float() + final_mask = keep_mask + # Step 4: Reweight policy loss + pg_loss = pg_loss * is_weight + # Step 5: Compute detailed stats over original_valid tokens + # Rejected mask + rejected_mask = original_valid & (~keep_mask) # [B, T] + # Clipped mask: only meaningful if IS weighting is active + if self.pipeline_config.infer_is_mode in ("token", "sequence"): + clipped_low = (raw_is_weight <= self.pipeline_config.infer_is_threshold_min) & original_valid + clipped_high = (raw_is_weight >= self.pipeline_config.infer_is_threshold_max) & original_valid + clipped_mask = clipped_low | clipped_high # [B, T] + else: + clipped_mask = torch.zeros_like(original_valid) # no clipping + # Compute fractions + def _compute_frac(mask_tensor): + return agg_loss( + loss_mat=mask_tensor.float(), + loss_mask=response_mask, + 
loss_agg_mode="token-mean" # force token-wise average + ).detach().item() + clip_frac = _compute_frac(clipped_mask) + reject_frac = _compute_frac(rejected_mask) + clip_and_reject_frac = _compute_frac(clipped_mask & rejected_mask) + clip_or_reject_frac = _compute_frac(clipped_mask | rejected_mask) + # A sequence is rejected if NO token is kept (i.e., all final_mask == 0 for that seq) + seq_has_valid = original_valid.any(dim=-1) # [B], bool: seq has >=1 valid token + seq_completely_rejected = (~keep_mask).all(dim=-1) & seq_has_valid # [B] + total_valid_seqs = seq_has_valid.sum().item() + rejected_seqs = seq_completely_rejected.sum().item() + seq_reject_frac = rejected_seqs / total_valid_seqs if total_valid_seqs > 0 else 0.0 + + ### kl metric + inferkl_orig = compute_approx_kl( + log_probs=infer_log_probs, + log_probs_base=old_log_probs, + action_mask=response_mask, # ← original mask + kl_penalty="kl" + ) + inferkl_final = compute_approx_kl( + log_probs=infer_log_probs, + log_probs_base=old_log_probs, + action_mask=final_mask, # ← after rejection + kl_penalty="kl" + ) + inferkl_orig_agg = agg_loss( + loss_mat=inferkl_orig, + loss_mask=response_mask, + loss_agg_mode=self.pipeline_config.loss_agg_mode + ).detach().item() + inferkl_final_agg = agg_loss( + loss_mat=inferkl_final, + loss_mask=final_mask, + loss_agg_mode=self.pipeline_config.loss_agg_mode + ).detach().item() + valid_raw_is_weight = raw_is_weight[original_valid] # [N_valid_tokens,] + if valid_raw_is_weight.numel() > 0: + raw_is_mean = valid_raw_is_weight.mean().detach().item() + raw_is_std = valid_raw_is_weight.std(unbiased=False).detach().item() + raw_is_min = valid_raw_is_weight.min().detach().item() + raw_is_max = valid_raw_is_weight.max().detach().item() + else: + # fallback if no valid tokens (rare edge case) + raw_is_mean = raw_is_std = raw_is_min = raw_is_max = 0.0 + stats = { + "infer_correction/reject_frac": reject_frac, + "infer_correction/clip_frac": clip_frac, + "infer_correction/clip_and_reject_frac": clip_and_reject_frac, + "infer_correction/clip_or_reject_frac": clip_or_reject_frac, + "infer_correction/seq_reject_frac": seq_reject_frac, + "infer_correction/inferkl_orig": inferkl_orig_agg, + "infer_correction/inferkl_final": inferkl_final_agg, + "infer_correction/raw_is_mean": raw_is_mean, + "infer_correction/raw_is_std": raw_is_std, + "infer_correction/raw_is_min": raw_is_min, + "infer_correction/raw_is_max": raw_is_max, + } + return pg_loss, final_mask, stats @register(dispatch_mode=Dispatch.ONE_TO_ALL) def do_checkpoint(self, global_step): if self.worker_config.offload_nccl: diff --git a/roll/pipeline/rlvr/actor_pg_worker.py b/roll/pipeline/rlvr/actor_pg_worker.py index 477438595..970d13b4d 100644 --- a/roll/pipeline/rlvr/actor_pg_worker.py +++ b/roll/pipeline/rlvr/actor_pg_worker.py @@ -30,6 +30,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ response_mask = data.batch["response_mask"][:, 1:].long() + infer_log_probs = data.batch.get("infer_logprobs", None) final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] @@ -76,7 +77,13 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): pg_loss = self._compute_kimi15_loss(ratio, log_probs, old_log_probs, advantages) else: raise ValueError(f"Unsupported pg_variant: {pg_variant}") - + + if infer_log_probs is not None and self.pipeline_config.infer_correction: + loss, infer_response_mask, infer_stats=self.infer_correction( + old_log_probs=old_log_probs, 
infer_log_probs=infer_log_probs, + response_mask=response_mask,pg_loss=loss) + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() + weighted_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights) original_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, @@ -127,6 +134,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): # 使用缓存的指标 pg_metrics = self._get_pg_metrics(data) + pg_metrics.updata(infer_stats) return total_loss, pg_metrics diff --git a/roll/pipeline/rlvr/actor_worker.py b/roll/pipeline/rlvr/actor_worker.py index 19d0c66de..7e0ad08c1 100644 --- a/roll/pipeline/rlvr/actor_worker.py +++ b/roll/pipeline/rlvr/actor_worker.py @@ -17,6 +17,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): response_mask = data.batch["response_mask"][:, 1:].long() final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] + advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( @@ -53,34 +54,6 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="kl" ) - train_infer_ratio = (old_log_probs - infer_log_probs).exp() - train_infer_diff = old_log_probs.exp() - infer_log_probs.exp() - train_infer_ratio_seq = masked_mean(old_log_probs - infer_log_probs, response_mask, dim=-1).exp().unsqueeze(-1).expand_as(train_infer_ratio) - train_infer_diff_seq = masked_mean(old_log_probs.exp() - infer_log_probs.exp(), response_mask, dim=-1).unsqueeze(-1).expand_as(train_infer_diff) - - train_infer_ratio_mask_mean = 1.0 - train_infer_diff_mask_mean = 1.0 - train_infer_ratio_seq_mask_mean = 1.0 - train_infer_diff_seq_mask_mean = 1.0 - - if self.pipeline_config.train_infer_ratio_mask: - train_infer_ratio_mask = (train_infer_ratio <= self.pipeline_config.train_infer_ratio_threshold_high).float() * (train_infer_ratio >= self.pipeline_config.train_infer_ratio_threshold_low).float() - train_infer_ratio_mask_mean = masked_mean(train_infer_ratio_mask, final_response_mask, dim=-1).mean().detach().item() - final_response_mask = final_response_mask * train_infer_ratio_mask - if self.pipeline_config.train_infer_diff_mask: - train_infer_diff_mask = (train_infer_diff <= self.pipeline_config.train_infer_diff_threshold_high).float() * (train_infer_diff >= self.pipeline_config.train_infer_diff_threshold_low).float() - train_infer_diff_mask_mean = masked_mean(train_infer_diff_mask, final_response_mask, dim=-1).mean().detach().item() - final_response_mask = final_response_mask * train_infer_diff_mask - - if self.pipeline_config.train_infer_ratio_seq_mask: - train_infer_ratio_seq_mask = (train_infer_ratio_seq <= self.pipeline_config.train_infer_ratio_seq_threshold_high).float() * (train_infer_ratio_seq >= self.pipeline_config.train_infer_ratio_seq_threshold_low).float() - train_infer_ratio_seq_mask_mean = masked_mean(train_infer_ratio_seq_mask, final_response_mask, dim=-1).mean().detach().item() - final_response_mask = final_response_mask * train_infer_ratio_seq_mask - if self.pipeline_config.train_infer_diff_seq_mask: - train_infer_diff_seq_mask = (train_infer_diff_seq <= self.pipeline_config.train_infer_diff_seq_threshold_high).float() * (train_infer_diff_seq >= self.pipeline_config.train_infer_diff_seq_threshold_low).float() - train_infer_diff_seq_mask_mean = 
masked_mean(train_infer_diff_seq_mask, final_response_mask, dim=-1).mean().detach().item() - final_response_mask = final_response_mask * train_infer_diff_seq_mask - if self.pipeline_config.importance_sampling == "token": ratio = (log_probs - old_log_probs).exp() elif self.pipeline_config.importance_sampling == "seq": @@ -98,11 +71,13 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) loss = torch.where(advantages < 0, dual_clip_loss, loss) - - if self.pipeline_config.use_rollout_importance_sampling_ratio: - rollout_importance_sampling_clip = (train_infer_ratio > self.pipeline_config.rollout_importance_sampling_ratio_upper_bound).float() - loss = train_infer_ratio.clamp(0, self.pipeline_config.rollout_importance_sampling_ratio_upper_bound) * loss - + + if infer_log_probs is not None and self.pipeline_config.infer_correction: + loss, infer_response_mask, infer_stats=self.infer_correction( + old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, + response_mask=response_mask,pg_loss=loss) + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() + weighted_pg_loss = agg_loss(loss_mat=loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights, loss_scale=loss_scale) @@ -149,16 +124,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): loss_scale=loss_scale) total_loss = total_loss + topr_neg_loss * self.pipeline_config.use_topr_neg_loss_coef metrics['actor/topr_neg_loss'] = topr_neg_loss.detach().item() - - train_infer_prob_metric = { - "actor/train_infer_ratio_mean": masked_mean(train_infer_ratio, response_mask, dim=-1).mean().detach().item(), - "actor/train_infer_diff_mean": masked_mean(train_infer_diff, response_mask, dim=-1).mean().detach().item(), - "actor/train_infer_ratio_mask_mean": train_infer_ratio_mask_mean, - "actor/train_infer_diff_mask_mean": train_infer_diff_mask_mean, - "actor/train_infer_ratio_seq_mask_mean": train_infer_ratio_seq_mask_mean, - "actor/train_infer_diff_seq_mask_mean": train_infer_diff_seq_mask_mean, - } - + loss_metric = { "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(), "actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(), @@ -170,9 +136,6 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): loss_agg_mode=self.pipeline_config.loss_agg_mode, loss_scale=loss_scale).detach().item(), } - if self.pipeline_config.use_rollout_importance_sampling_ratio: - loss_metric["actor/rollout_importance_sampling_clip"] = rollout_importance_sampling_clip.mean().detach().item() - pg_metrics = { "actor/pg_loss": original_pg_loss.detach().item(), "actor/weighted_pg_loss": weighted_pg_loss.detach().item(), @@ -190,9 +153,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): "actor/sample_weights_max": sample_weights.max().detach().item(), **metrics, **loss_metric, - **train_infer_prob_metric } - + pg_metrics.update(infer_stats) return total_loss, pg_metrics def compute_sample_weights(self, data: DataProto, response_mask: torch.Tensor): From e98c4cea7a5d56649711dc5e08db1078bd001fd8 Mon Sep 17 00:00:00 2001 From: millioniron Date: Sun, 7 Dec 2025 14:43:05 +0800 Subject: [PATCH 4/7] =?UTF-8?q?=E9=87=8D=E6=96=B0=E4=BF=AE=E8=AE=A2?= =?UTF-8?q?=E4=BA=86=E6=95=B4=E4=B8=AA=E7=9A=84=E6=8E=92=E7=89=88=EF=BC=8C?= 
=?UTF-8?q?=E6=8A=BD=E8=B1=A1=E5=87=BA=E4=BA=86=E4=B8=80=E4=B8=AA=E7=B1=BB?= =?UTF-8?q?=EF=BC=8C=E4=BD=BF=E5=BE=97=E5=8F=AF=E4=BB=A5=E6=9B=B4=E5=8A=A0?= =?UTF-8?q?=E8=87=AA=E7=94=B1=E5=AE=9A=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../rlvr_infer_correction_config.yaml | 12 +- roll/configs/base_config.py | 6 +- roll/configs/generating_args.py | 2 +- roll/pipeline/base_worker.py | 169 +------- roll/pipeline/rlvr/actor_pg_worker.py | 23 +- roll/pipeline/rlvr/actor_worker.py | 22 +- roll/utils/infer_correction.py | 380 ++++++++++++++++++ 7 files changed, 439 insertions(+), 175 deletions(-) create mode 100644 roll/utils/infer_correction.py diff --git a/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml index e2cbac8a7..1bbe3c8e8 100644 --- a/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml +++ b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml @@ -80,8 +80,8 @@ infer_is_threshold_min: 0.0 infer_is_threshold_max: 2.0 # 1.5~5.0 enable_token_reject: false -infer_token_rs_threshold_min: 0.0 -infer_token_rs_threshold_max: 2.0 # 2~10 +infer_token_mask_threshold_min: 0.0 +infer_token_mask_threshold_max: 2.0 # 2~10 enable_catastrophic_reject: false infer_catastrophic_threshold: 1e-4 @@ -89,12 +89,12 @@ infer_catastrophic_threshold: 1e-4 enable_seq_reject: None # enable_seq_reject: sequence -# infer_seq_rs_threshold_min: 0.1 -# infer_seq_rs_threshold_max: 10 +# infer_seq_mask_threshold_min: 0.1 +# infer_seq_mask_threshold_max: 10 # enable_seq_reject: geometric -# infer_seq_rs_threshold_min: 0.999 -# infer_seq_rs_threshold_max: 1.001 +# infer_seq_mask_threshold_min: 0.999 +# infer_seq_mask_threshold_max: 1.001 validation: data_args: diff --git a/roll/configs/base_config.py b/roll/configs/base_config.py index 8dc75d3f3..a6ea37f09 100644 --- a/roll/configs/base_config.py +++ b/roll/configs/base_config.py @@ -407,9 +407,9 @@ class PPOConfig(BaseConfig): default=False, metadata={"help": "Whether to apply importance sampling correction during inference."} ) - infer_is_mode: Literal["token", "sequence", "none"] = field( - default="token", - metadata={"help": "IS weighting mode: 'token' (per-token ratio), 'sequence' (per-sequence ratio), 'none' (no IS weighting)."} + infer_is_mode: Literal["token", "sequence", "None"] = field( + default="None", + metadata={"help": "IS weighting mode: 'token' (per-token ratio), 'sequence' (per-sequence ratio), 'None' (no IS weighting)."} ) # Clipping thresholds (used in IS weighting) infer_is_threshold_min: float = field( diff --git a/roll/configs/generating_args.py b/roll/configs/generating_args.py index 0822614ba..5d1f94f95 100644 --- a/roll/configs/generating_args.py +++ b/roll/configs/generating_args.py @@ -59,7 +59,7 @@ class GeneratingArguments: metadata={"help": "Whether to include the stop strings in output text."}, ) logprobs: Optional[int] = field( - default=None, + default=0, metadata={"help": "Whether return infer log-prob."}, ) diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py index 7148752a1..e7cb65389 100644 --- a/roll/pipeline/base_worker.py +++ b/roll/pipeline/base_worker.py @@ -25,10 +25,11 @@ postprocess_generate, GenerateRequestType, agg_loss, - masked_sum + masked_sum, ) from roll.utils.offload_nccl import reload_process_groups from roll.utils.offload_states import OffloadStateType +from roll.utils.infer_correction import InferCorrectionHandler from 
roll.utils.dynamic_batching import make_mini_batch_iter_for_dynamic_batching from roll.platforms import current_platform @@ -309,10 +310,10 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): response_mask = data.batch["response_mask"][:, 1:].long() final_response_mask = data.batch.get("final_response_mask", response_mask) + ref_log_probs = data.batch["ref_log_probs"] advantages = data.batch["advantages"] infer_log_probs = data.batch.get("infer_logprobs", None) - log_probs = self.strategy.op_compute_log_probs( logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"] ) @@ -328,19 +329,26 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-pg_loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) pg_loss = torch.where(advantages < 0, dual_clip_loss, pg_loss) - + + infer_stats = {} if infer_log_probs is not None and self.pipeline_config.infer_correction: - pg_loss, infer_response_mask, infer_stats=self.infer_correction( - old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, - response_mask=response_mask,pg_loss=pg_loss) + correction_handler = InferCorrectionHandler(self.pipeline_config) + + pg_loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=pg_loss + ) + # 更新最终掩码 final_response_mask = (final_response_mask.bool() & infer_response_mask).long() pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) - kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=final_response_mask, + kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=final_response_mask, + kl_penalty="k3") kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) - approxkl = compute_approx_kl( log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="mse" ) @@ -384,150 +392,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): pg_metrics.update(infer_stats) return total_loss, pg_metrics - - - def infer_correction( - self, - old_log_probs: torch.Tensor, - infer_log_probs: torch.Tensor, - response_mask: torch.Tensor, - pg_loss: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, dict]: - """ - 处理 importance sampling ratio,支持 IS 裁剪与多种 reject 策略。 - 返回更新后的 pg_loss、mask 和详细统计信息。 - """ - # Step 0: Shape alignment - if infer_log_probs.shape[1] == old_log_probs.shape[1]+1: - infer_log_probs = infer_log_probs[:, 1:] # align with response_mask[:, 1:] - assert old_log_probs.shape == infer_log_probs.shape == response_mask.shape, \ - f"Shape mismatch: {old_log_probs.shape}, {infer_log_probs.shape}, {response_mask.shape}" - # Step 1: Compute log-ratio and ratio - log_ratio = old_log_probs - infer_log_probs # [B, T] - ratio = torch.exp(log_ratio) # [B, T] - # Step 2: Apply IS weighting strategy (optional) - if self.pipeline_config.infer_is_mode == "token": - raw_is_weight = ratio - elif self.pipeline_config.infer_is_mode == "sequence": - log_ratio_sum = masked_sum(log_ratio, response_mask, dim=-1).unsqueeze(-1) # [B, 1] - raw_is_weight = torch.exp(log_ratio_sum).expand_as(ratio) # [B, T] - elif self.pipeline_config.infer_is_mode in (None, "none", ""): - raw_is_weight = torch.ones_like(ratio) - else: - raw_is_weight = 
torch.ones_like(ratio) - # Clamp to get final is_weight (used for loss) - is_weight = raw_is_weight.clamp( - min=self.pipeline_config.infer_is_threshold_min, - max=self.pipeline_config.infer_is_threshold_max - ).detach() - # Step 3: Build rejection mask - original_valid = response_mask > 0.5 # [B, T], bool - keep_mask = original_valid.clone() - # (a) Token-level ratio reject - if getattr(self.pipeline_config, 'enable_token_reject', False): - ratio_too_high = ratio > self.pipeline_config.infer_token_mask_threshold_max - ratio_too_low = ratio < self.pipeline_config.infer_token_mask_threshold_min - token_reject = ratio_too_high | ratio_too_low - keep_mask = keep_mask & (~token_reject) - # (b) Catastrophic reject - if getattr(self.pipeline_config, 'enable_catastrophic_reject', False): - catastrophic = (ratio < self.pipeline_config.infer_catastrophic_threshold) & original_valid - has_catastrophic = catastrophic.any(dim=-1, keepdim=True) - keep_mask = keep_mask & (~has_catastrophic) - # (c) Sequence-level reject - if getattr(self.pipeline_config, 'enable_seq_reject', False): - if self.pipeline_config.enable_seq_reject=="sequence": - log_ratio_sum = masked_sum(log_ratio, response_mask, dim=-1) # [B] - seq_ratio = torch.exp(log_ratio_sum) # [B] - seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max - seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min - seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) - keep_mask = keep_mask & (~seq_reject) - elif self.pipeline_config.enable_seq_reject=="geometric": - log_ratio_mean = masked_mean(log_ratio, response_mask, dim=-1) # [B] - seq_ratio = torch.exp(log_ratio_mean) # [B] - seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max - seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min - seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) - keep_mask = keep_mask & (~seq_reject) - # final_mask = keep_mask.float() - final_mask = keep_mask - # Step 4: Reweight policy loss - pg_loss = pg_loss * is_weight - # Step 5: Compute detailed stats over original_valid tokens - # Rejected mask - rejected_mask = original_valid & (~keep_mask) # [B, T] - # Clipped mask: only meaningful if IS weighting is active - if self.pipeline_config.infer_is_mode in ("token", "sequence"): - clipped_low = (raw_is_weight <= self.pipeline_config.infer_is_threshold_min) & original_valid - clipped_high = (raw_is_weight >= self.pipeline_config.infer_is_threshold_max) & original_valid - clipped_mask = clipped_low | clipped_high # [B, T] - else: - clipped_mask = torch.zeros_like(original_valid) # no clipping - # Compute fractions - def _compute_frac(mask_tensor): - return agg_loss( - loss_mat=mask_tensor.float(), - loss_mask=response_mask, - loss_agg_mode="token-mean" # force token-wise average - ).detach().item() - clip_frac = _compute_frac(clipped_mask) - reject_frac = _compute_frac(rejected_mask) - clip_and_reject_frac = _compute_frac(clipped_mask & rejected_mask) - clip_or_reject_frac = _compute_frac(clipped_mask | rejected_mask) - # A sequence is rejected if NO token is kept (i.e., all final_mask == 0 for that seq) - seq_has_valid = original_valid.any(dim=-1) # [B], bool: seq has >=1 valid token - seq_completely_rejected = (~keep_mask).all(dim=-1) & seq_has_valid # [B] - total_valid_seqs = seq_has_valid.sum().item() - rejected_seqs = seq_completely_rejected.sum().item() - seq_reject_frac = rejected_seqs / total_valid_seqs if total_valid_seqs > 0 else 0.0 - - ### kl metric - inferkl_orig = 
compute_approx_kl( - log_probs=infer_log_probs, - log_probs_base=old_log_probs, - action_mask=response_mask, # ← original mask - kl_penalty="kl" - ) - inferkl_final = compute_approx_kl( - log_probs=infer_log_probs, - log_probs_base=old_log_probs, - action_mask=final_mask, # ← after rejection - kl_penalty="kl" - ) - inferkl_orig_agg = agg_loss( - loss_mat=inferkl_orig, - loss_mask=response_mask, - loss_agg_mode=self.pipeline_config.loss_agg_mode - ).detach().item() - inferkl_final_agg = agg_loss( - loss_mat=inferkl_final, - loss_mask=final_mask, - loss_agg_mode=self.pipeline_config.loss_agg_mode - ).detach().item() - valid_raw_is_weight = raw_is_weight[original_valid] # [N_valid_tokens,] - if valid_raw_is_weight.numel() > 0: - raw_is_mean = valid_raw_is_weight.mean().detach().item() - raw_is_std = valid_raw_is_weight.std(unbiased=False).detach().item() - raw_is_min = valid_raw_is_weight.min().detach().item() - raw_is_max = valid_raw_is_weight.max().detach().item() - else: - # fallback if no valid tokens (rare edge case) - raw_is_mean = raw_is_std = raw_is_min = raw_is_max = 0.0 - stats = { - "infer_correction/reject_frac": reject_frac, - "infer_correction/clip_frac": clip_frac, - "infer_correction/clip_and_reject_frac": clip_and_reject_frac, - "infer_correction/clip_or_reject_frac": clip_or_reject_frac, - "infer_correction/seq_reject_frac": seq_reject_frac, - "infer_correction/inferkl_orig": inferkl_orig_agg, - "infer_correction/inferkl_final": inferkl_final_agg, - "infer_correction/raw_is_mean": raw_is_mean, - "infer_correction/raw_is_std": raw_is_std, - "infer_correction/raw_is_min": raw_is_min, - "infer_correction/raw_is_max": raw_is_max, - } - return pg_loss, final_mask, stats + @register(dispatch_mode=Dispatch.ONE_TO_ALL) def do_checkpoint(self, global_step): if self.worker_config.offload_nccl: diff --git a/roll/pipeline/rlvr/actor_pg_worker.py b/roll/pipeline/rlvr/actor_pg_worker.py index 970d13b4d..e83ed0a9b 100644 --- a/roll/pipeline/rlvr/actor_pg_worker.py +++ b/roll/pipeline/rlvr/actor_pg_worker.py @@ -4,6 +4,7 @@ from roll.distributed.scheduler.protocol import DataProto from roll.utils.functionals import masked_mean, agg_loss, compute_approx_kl from roll.pipeline.rlvr.actor_worker import ActorWorker +from roll.utils.infer_correction import InferCorrectionHandler class ActorPGWorker(ActorWorker): @@ -30,10 +31,11 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ response_mask = data.batch["response_mask"][:, 1:].long() - infer_log_probs = data.batch.get("infer_logprobs", None) final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] + old_log_probs = data.batch["old_log_probs"] + infer_log_probs = data.batch.get("infer_logprobs", None) advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( @@ -77,13 +79,19 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): pg_loss = self._compute_kimi15_loss(ratio, log_probs, old_log_probs, advantages) else: raise ValueError(f"Unsupported pg_variant: {pg_variant}") - + + infer_stats = {} if infer_log_probs is not None and self.pipeline_config.infer_correction: - loss, infer_response_mask, infer_stats=self.infer_correction( - old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, - response_mask=response_mask,pg_loss=loss) + correction_handler = InferCorrectionHandler(self.pipeline_config) + pg_loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + 
infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=pg_loss + ) + # 更新最终掩码 final_response_mask = (final_response_mask.bool() & infer_response_mask).long() - + weighted_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights) original_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, @@ -134,7 +142,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): # 使用缓存的指标 pg_metrics = self._get_pg_metrics(data) - pg_metrics.updata(infer_stats) + + pg_metrics.update(infer_stats) return total_loss, pg_metrics diff --git a/roll/pipeline/rlvr/actor_worker.py b/roll/pipeline/rlvr/actor_worker.py index 7e0ad08c1..da34d3a38 100644 --- a/roll/pipeline/rlvr/actor_worker.py +++ b/roll/pipeline/rlvr/actor_worker.py @@ -4,6 +4,7 @@ from roll.distributed.scheduler.protocol import DataProto from roll.pipeline.base_worker import ActorWorker as BaseActorWorker from roll.utils.functionals import masked_mean, agg_loss, compute_approx_kl +from roll.utils.infer_correction import InferCorrectionHandler class ActorWorker(BaseActorWorker): @@ -17,7 +18,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): response_mask = data.batch["response_mask"][:, 1:].long() final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] - + advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( @@ -71,13 +72,21 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) loss = torch.where(advantages < 0, dual_clip_loss, loss) - + + infer_stats = {} if infer_log_probs is not None and self.pipeline_config.infer_correction: - loss, infer_response_mask, infer_stats=self.infer_correction( - old_log_probs=old_log_probs, infer_log_probs=infer_log_probs, - response_mask=response_mask,pg_loss=loss) + correction_handler = InferCorrectionHandler(self.pipeline_config) + loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=loss + ) + # 更新最终掩码 final_response_mask = (final_response_mask.bool() & infer_response_mask).long() - + + + weighted_pg_loss = agg_loss(loss_mat=loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights, loss_scale=loss_scale) @@ -124,6 +133,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): loss_scale=loss_scale) total_loss = total_loss + topr_neg_loss * self.pipeline_config.use_topr_neg_loss_coef metrics['actor/topr_neg_loss'] = topr_neg_loss.detach().item() + loss_metric = { "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(), diff --git a/roll/utils/infer_correction.py b/roll/utils/infer_correction.py new file mode 100644 index 000000000..215203081 --- /dev/null +++ b/roll/utils/infer_correction.py @@ -0,0 +1,380 @@ +from typing import Literal, Optional, Tuple, Dict, Any +import torch + +class StatsCollector: + """统一收集诊断指标的类""" + def __init__(self, prefix: str = "infer_correction"): + self.prefix = prefix + self.stats: Dict[str, Any] = {} + self.tensor_stats: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {} + + def add(self, name: str, value: Any): + """添加标量指标""" + self.stats[f"{self.prefix}/{name}"] = value.item() if torch.is_tensor(value) else 
value + + def add_tensor_stat(self, name: str, tensor: torch.Tensor, mask: torch.Tensor): + """添加张量统计指标(延迟计算)""" + self.tensor_stats[name] = (tensor, mask) + + def compute_tensor_stats(self): + """严格遵循原始代码的数据移动策略""" + for name, (tensor, mask) in self.tensor_stats.items(): + # 1. 确保在同一设备上 + if tensor.device != mask.device: + mask = mask.to(tensor.device) + + # 2. 直接在原始代码风格中计算:先筛选,再移动到CPU + mask=mask.bool() + valid = tensor[mask] + + # 3. 严格按照原始代码逻辑处理 + if valid.numel() > 0: + # 关键:先detach()再item(),确保在CPU上计算 + valid_cpu = valid.detach().cpu() + self.add(f"{name}_mean", valid_cpu.mean().item()) + self.add(f"{name}_std", valid_cpu.std(unbiased=False).item() if valid_cpu.numel() > 1 else 0.0) + self.add(f"{name}_min", valid_cpu.min().item()) + self.add(f"{name}_max", valid_cpu.max().item()) + else: + self.add(f"{name}_mean", 0.0) + self.add(f"{name}_std", 0.0) + self.add(f"{name}_min", 0.0) + self.add(f"{name}_max", 0.0) + + self.tensor_stats.clear() + + def get_metrics(self) -> Dict[str, float]: + """获取所有指标""" + return self.stats.copy() + +class InferCorrectionHandler: + """处理重要性采样校正和样本拒绝的核心类""" + def __init__(self, pipeline_config: "PPOConfig"): + self.pipeline_config = pipeline_config + self.stats = StatsCollector() + + def __call__( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor, + pg_loss: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]: + """ + 主入口:执行重要性采样校正和样本拒绝 + + Args: + old_log_probs: 历史策略的log概率 [B, T] + infer_log_probs: 生成时策略的log概率 [B, T] + response_mask: 有效token掩码 [B, T] + pg_loss: 原始策略梯度损失 [B, T] + + Returns: + weighted_loss: 重加权后的损失 + final_mask: 最终保留的token掩码 + metrics: 诊断指标字典 + """ + # 1. 对齐形状 + infer_log_probs = self._align_shapes(old_log_probs, infer_log_probs, response_mask) + + # 2. 计算IS权重 + ratio, raw_is_weight, is_weight = self._compute_is_weights(old_log_probs, infer_log_probs, response_mask) + + # 3. 收集基础统计 + self._collect_base_stats(ratio, response_mask) + + # 4. 应用拒绝策略 + keep_mask = response_mask.clone() + keep_mask = self._apply_token_rejection(ratio, keep_mask) + keep_mask = self._apply_catastrophic_rejection(ratio, keep_mask, response_mask) + keep_mask = self._apply_sequence_rejection(ratio, keep_mask, response_mask) + + # 5. 计算拒绝统计 + self._collect_rejection_stats(ratio, raw_is_weight, keep_mask, response_mask) + + # 6. 重加权损失 + weighted_loss = pg_loss * is_weight + + # 7. 计算KL指标 + self._compute_kl_metrics(old_log_probs, infer_log_probs, keep_mask, response_mask) + + # 8. 
批量计算张量统计 + self.stats.compute_tensor_stats() + + return weighted_loss, keep_mask, self.stats.get_metrics() + + def _align_shapes( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor + ) -> torch.Tensor: + """对齐log概率张量形状""" + if infer_log_probs.shape[1] == old_log_probs.shape[1] + 1: + infer_log_probs = infer_log_probs[:, 1:] + + assert old_log_probs.shape == infer_log_probs.shape == response_mask.shape, ( + f"Shape mismatch: old_log_probs {old_log_probs.shape}, " + f"infer_log_probs {infer_log_probs.shape}, " + f"response_mask {response_mask.shape}" + ) + return infer_log_probs + + def _compute_is_weights( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + 计算重要性采样权重 + + Returns: + ratio: 原始重要性比率 [B, T] + raw_is_weight: 未裁剪的IS权重 [B, T] + is_weight: 裁剪后的IS权重 [B, T] + """ + log_ratio = old_log_probs - infer_log_probs + ratio = torch.exp(log_ratio) + + if self.pipeline_config.infer_is_mode == "token": + raw_is_weight = ratio + elif self.pipeline_config.infer_is_mode == "sequence": + # 序列级IS:使用序列总log-ratio + log_ratio_sum = self._masked_sum(log_ratio, response_mask, dim=-1).unsqueeze(-1) + seq_ratio = torch.exp(log_ratio_sum) + raw_is_weight = seq_ratio.expand_as(ratio) + # 收集序列级统计 + self.stats.add_tensor_stat("seq_ratio", seq_ratio.squeeze(-1), torch.ones_like(seq_ratio.squeeze(-1), dtype=torch.bool)) + else: # "None" or any other value + raw_is_weight = torch.ones_like(ratio) + + # 裁剪IS权重 + is_weight = raw_is_weight.clamp( + min=self.pipeline_config.infer_is_threshold_min, + max=self.pipeline_config.infer_is_threshold_max + ).detach() + + return ratio, raw_is_weight, is_weight + + def _collect_base_stats(self, ratio: torch.Tensor, response_mask: torch.Tensor): + """收集基础统计指标""" + self.stats.add_tensor_stat("token_ratio", ratio, response_mask) + + if self.pipeline_config.infer_is_mode in ("token", "sequence"): + # 1. 裁剪比例统计(现有代码) + clipped_low = ratio <= self.pipeline_config.infer_is_threshold_min + clipped_high = ratio >= self.pipeline_config.infer_is_threshold_max + clipped = clipped_low | clipped_high + self.stats.add("token_clip_low_frac", self._agg_loss(clipped_low.float(), response_mask)) + self.stats.add("token_clip_high_frac", self._agg_loss(clipped_high.float(), response_mask)) + self.stats.add("token_clip_frac", self._agg_loss(clipped.float(), response_mask)) + + # 2. 
添加缺失的:裁剪后权重的分布统计 + if self.pipeline_config.infer_is_mode == "token": + # 重新计算裁剪后的权重 + is_weight = ratio.clamp( + min=self.pipeline_config.infer_is_threshold_min, + max=self.pipeline_config.infer_is_threshold_max + ) + # 添加缺失的统计 + self.stats.add_tensor_stat("token_is_weight", is_weight, response_mask) + + elif self.pipeline_config.infer_is_mode == "sequence": + # 序列级IS权重已在_compute_is_weights中添加 + pass + + def _apply_token_rejection( + self, + ratio: torch.Tensor, + keep_mask: torch.Tensor + ) -> torch.Tensor: + """应用token级拒绝策略""" + if not self.pipeline_config.enable_token_reject: + return keep_mask + + ratio_too_high = ratio > self.pipeline_config.infer_token_mask_threshold_max + ratio_too_low = ratio < self.pipeline_config.infer_token_mask_threshold_min + token_reject = ratio_too_high | ratio_too_low + + # 更新掩码:丢弃被拒绝的token + new_keep_mask = keep_mask & (~token_reject) + + # 收集统计 + self.stats.add("token_reject_low_frac", self._agg_loss(ratio_too_low.float(), keep_mask)) + self.stats.add("token_reject_high_frac", self._agg_loss(ratio_too_high.float(), keep_mask)) + + return new_keep_mask + + def _apply_catastrophic_rejection( + self, + ratio: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ) -> torch.Tensor: + """应用灾难性拒绝策略""" + if not self.pipeline_config.enable_catastrophic_reject: + return keep_mask + + # 识别灾难性token + catastrophic = (ratio < self.pipeline_config.infer_catastrophic_threshold) & response_mask + + # 检查哪些序列包含灾难性token + seq_has_catastrophic = catastrophic.any(dim=-1, keepdim=True) + + # 更新掩码:丢弃包含灾难性token的整个序列 + new_keep_mask = keep_mask & (~seq_has_catastrophic) + + # 收集统计 + catastrophic_token_frac = self._agg_loss(catastrophic.float(), response_mask) + self.stats.add("catastrophic_token_frac", catastrophic_token_frac) + + # 计算包含灾难性token的序列比例 + seq_has_valid = response_mask.any(dim=-1) + seq_has_catastrophic_flat = catastrophic.any(dim=-1) & seq_has_valid + catastrophic_seq_frac = ( + seq_has_catastrophic_flat.sum().float() / seq_has_valid.sum().float() + if seq_has_valid.sum() > 0 else 0.0 + ) + self.stats.add("catastrophic_seq_frac", catastrophic_seq_frac) + + return new_keep_mask + + def _apply_sequence_rejection( + self, + ratio: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ) -> torch.Tensor: + """应用序列级拒绝策略""" + if self.pipeline_config.enable_seq_reject in (None, "None", "none"): + return keep_mask + + # 计算序列级比率 + if self.pipeline_config.enable_seq_reject == "sequence": + log_ratio_agg = self._masked_sum(torch.log(ratio), response_mask, dim=-1) + elif self.pipeline_config.enable_seq_reject == "geometric": + log_ratio_agg = self._masked_mean(torch.log(ratio), response_mask, dim=-1) + else: + return keep_mask + + seq_ratio = torch.exp(log_ratio_agg) + + # 识别要拒绝的序列 + seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max + seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min + seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) + + # 更新掩码 + new_keep_mask = keep_mask & (~seq_reject) + + # 收集统计 + seq_has_valid = response_mask.any(dim=-1) + total_valid_seqs = seq_has_valid.sum().item() + + seq_reject_low = seq_too_low & seq_has_valid + seq_reject_high = seq_too_high & seq_has_valid + + seq_reject_low_frac = seq_reject_low.sum().item() / total_valid_seqs if total_valid_seqs > 0 else 0.0 + seq_reject_high_frac = seq_reject_high.sum().item() / total_valid_seqs if total_valid_seqs > 0 else 0.0 + + self.stats.add("seq_reject_low_frac", seq_reject_low_frac) + 
self.stats.add("seq_reject_high_frac", seq_reject_high_frac) + + return new_keep_mask + + def _collect_rejection_stats( + self, + ratio: torch.Tensor, + raw_is_weight: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ): + """收集拒绝相关的统计指标""" + + # 计算被拒绝的token + rejected_mask = response_mask & (~keep_mask) + self.stats.add("reject_frac", self._agg_loss(rejected_mask.float(), response_mask)) + + + # 仅在序列拒绝启用时计算序列级拒绝率 + if self.pipeline_config.enable_seq_reject not in (None, "None", "none"): + seq_has_valid = response_mask.any(dim=-1) + seq_completely_rejected = (~keep_mask).all(dim=-1) & seq_has_valid + total_valid_seqs = seq_has_valid.sum().item() + rejected_seqs = seq_completely_rejected.sum().item() + seq_reject_frac = rejected_seqs / total_valid_seqs if total_valid_seqs > 0 else 0.0 + self.stats.add("seq_reject_frac", seq_reject_frac) + else: + # 未启用时显式设为0.0 + self.stats.add("seq_reject_frac", 0.0) + + + if self.pipeline_config.infer_is_mode in ("token", "sequence"): + # 使用已计算的rejected_mask + clipped_mask = ((raw_is_weight <= self.pipeline_config.infer_is_threshold_min) | + (raw_is_weight >= self.pipeline_config.infer_is_threshold_max)) & response_mask + + clip_and_reject_frac = self._agg_loss((clipped_mask & rejected_mask).float(), response_mask) + clip_or_reject_frac = self._agg_loss((clipped_mask | rejected_mask).float(), response_mask) + + self.stats.add("token_clip_and_reject_frac", clip_and_reject_frac) + self.stats.add("token_clip_or_reject_frac", clip_or_reject_frac) + else: + # 关键:为未启用IS的情况提供默认值 + self.stats.add("token_clip_and_reject_frac", 0.0) + self.stats.add("token_clip_or_reject_frac", 0.0) + + def _compute_kl_metrics( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ): + """计算KL散度指标""" + # 原始KL(所有有效token) + inferkl_orig = self._compute_approx_kl(infer_log_probs, old_log_probs, response_mask, kl_penalty="kl") + inferkl_orig_agg = self._agg_loss(inferkl_orig, response_mask) + self.stats.add("inferkl", inferkl_orig_agg) + + # 拒绝后KL(仅保留的token) + inferkl_final = self._compute_approx_kl(infer_log_probs, old_log_probs, keep_mask, kl_penalty="kl") + inferkl_final_agg = self._agg_loss(inferkl_final, keep_mask) + self.stats.add("inferkl_reject", inferkl_final_agg) + + # --- 辅助方法(使用已有工具函数)--- + def _compute_approx_kl( + self, + log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + action_mask: torch.Tensor, + kl_penalty: str = "kl" + ) -> torch.Tensor: + """使用已有的compute_approx_kl函数计算近似KL散度""" + from roll.utils.functionals import compute_approx_kl + return compute_approx_kl( + log_probs=log_probs, + log_probs_base=log_probs_base, + action_mask=action_mask, + kl_penalty=kl_penalty + ) + + def _agg_loss(self, loss_mat: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor: + """使用已有的agg_loss函数聚合损失""" + from roll.utils.functionals import agg_loss + return agg_loss( + loss_mat=loss_mat, + loss_mask=loss_mask, + loss_agg_mode=self.pipeline_config.loss_agg_mode + ) + + def _masked_sum(self, tensor: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: + """使用已有的masked_sum函数在掩码区域求和""" + from roll.utils.functionals import masked_sum + return masked_sum(tensor, mask, dim=dim) + + def _masked_mean(self, tensor: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: + """使用已有的masked_mean函数在掩码区域计算均值""" + from roll.utils.functionals import masked_mean + return masked_mean(tensor, mask, dim=dim) \ No newline at end of file From 
c3d1121e10cd07cb7bd068ac3e4afc3c867cd533 Mon Sep 17 00:00:00 2001 From: millioniron Date: Mon, 8 Dec 2025 20:04:15 +0800 Subject: [PATCH 5/7] modified: roll/pipeline/agentic/env_manager/step_env_manager.py --- roll/pipeline/agentic/env_manager/step_env_manager.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/roll/pipeline/agentic/env_manager/step_env_manager.py b/roll/pipeline/agentic/env_manager/step_env_manager.py index 7db4d0df0..a1d7cbf28 100644 --- a/roll/pipeline/agentic/env_manager/step_env_manager.py +++ b/roll/pipeline/agentic/env_manager/step_env_manager.py @@ -79,7 +79,6 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): input_ids =torch.tensor(token_ids, dtype=torch.long).unsqueeze(0) attention_mask = torch.tensor([1] * len(token_ids), dtype=torch.long).unsqueeze(0) response_mask = torch.tensor(response_masks, dtype=torch.bool).unsqueeze(0) - infer_logprobs = [] if "infer_logprobs" in history: infer_logprobs = [0] * len(history["prompt_ids"]) + history["infer_logprobs"] @@ -88,7 +87,6 @@ prompt_mask = torch.tensor(prompt_masks, dtype=torch.bool).unsqueeze(0) score_tensor = torch.tensor([0] * len(token_ids), dtype=torch.float).unsqueeze(0) score_tensor[0][-1] = history['reward'] - infer_logprobs = history["infer_logprobs"].flatten().unsqueeze(0) position_ids = attention_mask.cumsum(dim=-1) input_ids = pad_to_length(input_ids, length=self.pipeline_config.sequence_length, pad_value=self.tokenizer.pad_token_id) @@ -106,7 +104,6 @@ "response_mask": response_mask, "prompt_mask": prompt_mask, "scores": score_tensor, - "infer_logprobs": infer_logprobs, }, batch_size=input_ids.shape[0]), non_tensor_batch={ From 515bf39e15de7c249cc0edaf87ead8a7c24bc087 Mon Sep 17 00:00:00 2001 From: millioniron Date: Mon, 8 Dec 2025 20:28:40 +0800 Subject: [PATCH 6/7] Remove the original official train-infer implementation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- roll/pipeline/agentic/agentic_actor_worker.py | 31 ++++++++++++------- roll/pipeline/rlvr/rlvr_config.py | 16 ---------- 2 files changed, 19 insertions(+), 28 deletions(-) diff --git a/roll/pipeline/agentic/agentic_actor_worker.py b/roll/pipeline/agentic/agentic_actor_worker.py index 75510c675..e97fef5ba 100644 --- a/roll/pipeline/agentic/agentic_actor_worker.py +++ b/roll/pipeline/agentic/agentic_actor_worker.py @@ -4,6 +4,7 @@ from roll.distributed.scheduler.protocol import DataProto from roll.pipeline.base_worker import ActorWorker as BaseActorWorker from roll.utils.functionals import masked_mean, agg_loss, compute_approx_kl +from roll.utils.infer_correction import InferCorrectionHandler class ActorWorker(BaseActorWorker): @@ -14,6 +15,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): output_tensor: torch.Tensor, model.forward()的输出Tensor """ response_mask = data.batch["response_mask"][:, 1:].long() + final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] advantages = data.batch["advantages"] @@ -29,8 +31,6 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): else: ratio = (log_probs - old_log_probs).exp() - train_infer_ratio = (log_probs - infer_log_probs).exp() - train_infer_diff = log_probs.exp() - infer_log_probs.exp() pg_clip_low = self.pipeline_config.pg_clip_low if
self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip pg_clip_high = self.pipeline_config.pg_clip_high if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip @@ -41,11 +41,23 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): dual_clip_loss = -torch.max(-pg_loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) pg_loss = torch.where(advantages < 0, dual_clip_loss, pg_loss) - pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + infer_stats = {} + if infer_log_probs is not None and self.pipeline_config.infer_correction: + correction_handler = InferCorrectionHandler(self.pipeline_config) + pg_loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=pg_loss + ) + # 更新最终掩码 + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() + + pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) - kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=response_mask, + kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=final_response_mask, kl_penalty="k3") - kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) approxkl = compute_approx_kl( log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="mse" @@ -70,11 +82,6 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): ) total_loss = total_loss - entropy_loss * self.pipeline_config.entropy_loss_coef - train_infer_prob_metric = { - "actor/train_infer_ratio_mean": masked_mean(train_infer_ratio, response_mask, dim=-1).mean().detach().item(), - "actor/train_infer_diff_mean": masked_mean(train_infer_diff, response_mask, dim=-1).mean().detach().item(), - } - pg_metrics = { "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(), "actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(), @@ -91,8 +98,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(), "actor/policykl": agg_loss(loss_mat=policykl, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(), - **train_infer_prob_metric } - + pg_metrics.update(infer_stats) + return total_loss, pg_metrics diff --git a/roll/pipeline/rlvr/rlvr_config.py b/roll/pipeline/rlvr/rlvr_config.py index ba51d62ef..d053e2e9a 100644 --- a/roll/pipeline/rlvr/rlvr_config.py +++ b/roll/pipeline/rlvr/rlvr_config.py @@ -149,22 +149,6 @@ class RLVRConfig(PPOConfig): importance_sampling: Literal["token", "seq"] = ( field(default="token", metadata={"help": "policy importance sampling"}) ) - use_rollout_importance_sampling_ratio: bool = field(default=False, metadata={"help": "apply train/infer ratio as token-level loss weight"}) - rollout_importance_sampling_ratio_upper_bound: float = field(default=1.2) - - train_infer_ratio_mask: bool = field(default=False, metadata={"help": "apply train/infer ratio as token-level response mask"}) - train_infer_ratio_threshold_low: float = field(default=0.8) - train_infer_ratio_threshold_high: float = field(default=1.2) - train_infer_diff_mask: bool 
= field(default=False, metadata={"help": "apply train-infer diff as token-level response mask"}) - train_infer_diff_threshold_low: float = field(default=-0.2) - train_infer_diff_threshold_high: float = field(default=0.2) - - train_infer_ratio_seq_mask: bool = field(default=False, metadata={"help": "apply train/infer ratio as sequence-level response mask"}) - train_infer_ratio_seq_threshold_low: float = field(default=0.8) - train_infer_ratio_seq_threshold_high: float = field(default=1.2) - train_infer_diff_seq_mask: bool = field(default=False, metadata={"help": "apply train-infer diff as sequence-level response mask"}) - train_infer_diff_seq_threshold_low: float = field(default=-0.2) - train_infer_diff_seq_threshold_high: float = field(default=0.2) val_greedy: bool = field(default=False, metadata={"help": "Use greedy for validation"}) val_n_sample: int = field(default=1, metadata={"help": "Number of samples for validation"}) From 243a961ad43a907ccd8904445eca5a0d550f3175 Mon Sep 17 00:00:00 2001 From: millioniron Date: Mon, 8 Dec 2025 23:11:04 +0800 Subject: [PATCH 7/7] modified: roll/pipeline/agentic/env_manager/step_env_manager.py --- roll/pipeline/agentic/env_manager/step_env_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/roll/pipeline/agentic/env_manager/step_env_manager.py b/roll/pipeline/agentic/env_manager/step_env_manager.py index a1d7cbf28..3eb8cefdf 100644 --- a/roll/pipeline/agentic/env_manager/step_env_manager.py +++ b/roll/pipeline/agentic/env_manager/step_env_manager.py @@ -79,6 +79,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): input_ids =torch.tensor(token_ids, dtype=torch.long).unsqueeze(0) attention_mask = torch.tensor([1] * len(token_ids), dtype=torch.long).unsqueeze(0) response_mask = torch.tensor(response_masks, dtype=torch.bool).unsqueeze(0) + infer_logprobs=[] if "infer_logprobs" in history: infer_logprobs = [0] * len(history["prompt_ids"]) + history["infer_logprobs"]
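
The workers patched above all follow the same pattern: build the per-token `pg_loss`, hand `old_log_probs`, `infer_log_probs`, `response_mask`, and `pg_loss` to the new `InferCorrectionHandler`, then AND the returned mask into `final_response_mask` and log the returned metrics. Below is a minimal sketch, not part of the patch series, that drives the handler on toy tensors. It assumes the patched `roll` package is importable; the `SimpleNamespace` is only a stand-in for a real `PPOConfig` and carries just the fields the handler reads, with values mirroring the example configs earlier in the series.

```python
# Minimal sketch: exercising roll.utils.infer_correction.InferCorrectionHandler on toy data.
# Assumes the patched `roll` package is on PYTHONPATH; the namespace below is a stand-in
# for PPOConfig and only carries the fields the handler actually reads.
from types import SimpleNamespace

import torch

from roll.utils.infer_correction import InferCorrectionHandler

config = SimpleNamespace(
    infer_correction=True,
    infer_is_mode="token",              # "token" | "sequence" | "None"
    infer_is_threshold_min=0.0,
    infer_is_threshold_max=2.0,
    enable_token_reject=True,
    infer_token_mask_threshold_min=0.0,
    infer_token_mask_threshold_max=2.0,
    enable_catastrophic_reject=True,
    infer_catastrophic_threshold=1e-4,
    enable_seq_reject=None,             # None | "sequence" | "geometric"
    infer_seq_mask_threshold_min=0.1,
    infer_seq_mask_threshold_max=10.0,
    loss_agg_mode="token-mean",
)

B, T = 2, 8
old_log_probs = -torch.rand(B, T)                            # trainer (old policy) log-probs
infer_log_probs = old_log_probs + 0.05 * torch.randn(B, T)   # generator log-probs
response_mask = torch.ones(B, T, dtype=torch.long)           # all response tokens valid
pg_loss = torch.randn(B, T)                                  # per-token policy-gradient loss

handler = InferCorrectionHandler(config)
weighted_loss, keep_mask, metrics = handler(
    old_log_probs=old_log_probs,
    infer_log_probs=infer_log_probs,
    response_mask=response_mask,
    pg_loss=pg_loss,
)

# `weighted_loss` is the IS-reweighted loss, `keep_mask` is what the workers AND into
# final_response_mask, and `metrics` is logged under the infer_correction/* prefix.
print(weighted_loss.shape, keep_mask.shape)
print(sorted(metrics))
```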
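The commented-out values for the renamed `infer_seq_mask_threshold_min/max` keys differ by orders of magnitude between the `sequence` mode (0.1 / 10) and the `geometric` mode (0.999 / 1.001). A short, self-contained sketch of the two aggregations used by `_apply_sequence_rejection` (sum versus mean of the masked log-ratio) shows why; the numbers are toy values chosen for illustration.

```python
# Toy illustration: why "sequence" and "geometric" sequence-level IS ratios
# live on very different scales. Values are made up for illustration only.
import torch

T = 200                                    # response length
log_ratio = torch.full((1, T), 0.01)       # old_log_probs - infer_log_probs, tiny per-token gap
mask = torch.ones_like(log_ratio)

seq_ratio = torch.exp((log_ratio * mask).sum(dim=-1))                     # "sequence": sum of log-ratios
geo_ratio = torch.exp((log_ratio * mask).sum(dim=-1) / mask.sum(dim=-1))  # "geometric": mean log-ratio

print(seq_ratio.item())  # ~7.39: compounds with length, hence a wide window such as 0.1 .. 10
print(geo_ratio.item())  # ~1.01: length-independent, hence a window tight around 1.0
```

Because the summed form compounds over the response length, its rejection window has to widen as responses get longer, while the geometric form stays near 1.0 regardless of length and therefore uses a much tighter window.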