diff --git a/docs_roll/docs/User Guides/Configuration/infer_correction.md b/docs_roll/docs/User Guides/Configuration/infer_correction.md
new file mode 100644
index 000000000..ecdce6325
--- /dev/null
+++ b/docs_roll/docs/User Guides/Configuration/infer_correction.md
@@ -0,0 +1,142 @@
+# Training-Inference Mismatch Correction
+
+## Overview
+
+During RL training the trainer and the rollout generator run on different backends (vLLM vs. SGLang vs. FSDP vs. Megatron) and often at different numerical precisions (FP8 vs. FP16 vs. BF16 vs. FP32). The resulting mismatch between the two policies behaves like an off-policy gap and can cause unstable training and policy collapse.
+
+## How It Works
+
+There are roughly two ways to repair the mismatch: (1) apply a policy correction between the trainer and the generator, or (2) use `infer_log_probs` directly in place of the trainer's `old_log_probs` when computing the PPO ratio. The second option is straightforward, so this guide focuses on the first.
+
+### Policy correction between trainer and generator
+
+#### IS weight correction
+Importance-sampling (IS) correction between the trainer (`old_log_probs`) and the generator (`infer_log_probs`) bridges the training-inference gap. As in off-policy algorithms, the IS weight can be computed at the token level or at the sequence level; only one mode can be active at a time.
+
+#### Mask filtering of offending samples
+Unlike IS re-weighting, this approach directly masks out samples whose ratio exceeds the configured thresholds. Four variants are available: (1) token level: mask individual out-of-range tokens; (2) catastrophic tokens: drop the whole sequence if any token shows a catastrophically large deviation; (3) sequence level: compute a sequence-level IS ratio and drop out-of-range sequences; (4) sequence level with a geometric mean, which makes the metric more sensitive. A simplified sketch of both mechanisms is shown below.
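+
+The snippet below is an illustrative sketch of the two mechanisms, not the library implementation (the actual logic lives in `InferCorrectionHandler` in `roll/utils/infer_correction.py`); tensors are `[B, T]`, and the helper name and keyword arguments are made up for this example.
+
+```python
+import torch
+
+def infer_correction_sketch(old_log_probs, infer_log_probs, response_mask, pg_loss,
+                            is_mode="token", is_min=0.0, is_max=2.0,
+                            token_reject=False, tok_min=0.0, tok_max=2.0,
+                            catastrophic_threshold=None,
+                            seq_reject=None, seq_min=0.1, seq_max=10.0):
+    """Re-weight pg_loss with clipped IS weights and mask offending tokens/sequences."""
+    mask = response_mask.bool()                  # valid response tokens, [B, T]
+    log_ratio = old_log_probs - infer_log_probs  # trainer-vs-generator log-ratio
+    ratio = log_ratio.exp()
+
+    # (1) IS re-weighting of the policy-gradient loss
+    if is_mode == "token":
+        is_weight = ratio
+    elif is_mode == "sequence":
+        seq_log_ratio = (log_ratio * mask).sum(-1, keepdim=True)  # sum over valid tokens
+        is_weight = seq_log_ratio.exp().expand_as(ratio)
+    else:  # "None": no re-weighting
+        is_weight = torch.ones_like(ratio)
+    pg_loss = pg_loss * is_weight.clamp(is_min, is_max).detach()
+
+    # (2) mask filtering of offending tokens / sequences
+    keep = mask.clone()
+    if token_reject:
+        keep &= (ratio >= tok_min) & (ratio <= tok_max)
+    if catastrophic_threshold is not None:
+        # drop the whole sequence if any valid token deviates catastrophically
+        keep &= ~((ratio < catastrophic_threshold) & mask).any(-1, keepdim=True)
+    if seq_reject in ("sequence", "geometric"):
+        seq_log = (log_ratio * mask).sum(-1, keepdim=True)
+        if seq_reject == "geometric":  # mean log-ratio == geometric-mean ratio
+            seq_log = seq_log / mask.sum(-1, keepdim=True).clamp(min=1)
+        seq_ratio = seq_log.exp()
+        keep &= (seq_ratio >= seq_min) & (seq_ratio <= seq_max)
+    return pg_loss, keep
+```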
+
+## Key Configuration Parameters
+
+Whether the generator returns `infer_log_probs`:
+
+GeneratingArguments:
+
+```yaml
+actor_infer:
+  model_args:
+    disable_gradient_checkpointing: true
+    dtype: fp16
+  generating_args:
+    max_new_tokens: ${response_length}
+    top_p: 0.99
+    top_k: 100
+    num_beams: 1
+    temperature: 0.99
+    num_return_sequences: ${num_return_sequences_in_group}
+    logprobs: 1
+```
+When `logprobs` is greater than 0, `infer_log_probs` are returned.
+
+-------
+
+The correction parameters are configured in `PPOConfig`; the key ones are listed below:
+
+```yaml
+infer_correction: true
+
+infer_is_mode: token                 # one of: token / sequence
+infer_is_threshold_min: 0.0
+infer_is_threshold_max: 2.0          # 1.5~5.0
+
+enable_token_reject: true
+infer_token_mask_threshold_min: 0.0
+infer_token_mask_threshold_max: 2.0  # 2~10
+
+enable_catastrophic_reject: true
+infer_catastrophic_threshold: 1e-4
+
+enable_seq_reject: sequence          # one of: None / sequence / geometric
+infer_seq_mask_threshold_min: 0.1
+infer_seq_mask_threshold_max: 10
+```
+
+### infer_correction
+- **Meaning**: Whether the training-inference mismatch correction is enabled. When enabled, `infer_log_probs` are used to correct the policy gradient.
+- **Default**: `false`
+
+### infer_is_mode
+- **Meaning**: Granularity at which importance-sampling (IS) weights are computed.
+- **Options**:
+  - `"token"`: an independent IS weight per token
+  - `"sequence"`: the log-ratio is summed over the whole response and broadcast to all tokens
+  - `"None"` (or unset): no IS weighting, the weight is constantly 1
+- **Default**: `"None"`
+- **Note**: Only one mode can be used at a time.
+
+### infer_is_threshold_min
+- **Meaning**: Lower clipping bound for the IS weight, used to limit variance from very small weights.
+- **Default**: `0.0`
+- **Recommendation**: Usually keep `0.0`, which leaves the lower side effectively unclipped.
+
+### infer_is_threshold_max
+- **Meaning**: Upper clipping bound for the IS weight, preventing extremely large weights from dominating the gradient.
+- **Default**: `2.0`
+- **Recommendation**: `1.5 ~ 5.0` at the `"token"` level; `2.0 ~ 10.0` at the `"sequence"` level
+
+### enable_token_reject
+- **Meaning**: Whether token-level sample rejection is enabled.
+- **Default**: `false`
+- **Effect**: Together with `infer_token_mask_threshold_min/max`, masks out tokens whose IS ratio falls outside the allowed range.
+
+### infer_token_mask_threshold_min
+- **Meaning**: Lower bound on the token-level IS ratio, i.e. `exp(old_log_probs - infer_log_probs)`.
+- **Default**: `0.0`
+- **Typical value**: `0.0`; it can also be set to `1/max` (the reciprocal of the corresponding max threshold)
+
+### infer_token_mask_threshold_max
+- **Meaning**: Upper bound on the token-level IS ratio above which tokens are masked.
+- **Default**: `2.0`
+- **Typical range**: `1.5 ~ 5.0`
+
+### enable_catastrophic_reject
+- **Meaning**: Whether "catastrophic deviation" detection is enabled, rejecting the whole sequence when triggered.
+- **Default**: `false`
+- **Trigger**: If any token in a sequence satisfies `ratio < infer_catastrophic_threshold`, the entire sequence is masked.
+
+### infer_catastrophic_threshold
+- **Meaning**: Lower ratio threshold for catastrophic rejection.
+- **Default**: `1e-4`
+- **Explanation**: When `infer_log_probs` is much larger than `old_log_probs` (the generator is far too "confident"), `ratio = exp(old - infer)` becomes extremely small, which signals a catastrophic deviation.
+
+### enable_seq_reject
+- **Meaning**: Whether sequence-level rejection is enabled, and which aggregation is used.
+- **Options**:
+  - `null` / `false`: disabled
+  - `"sequence"`: the sequence IS ratio is computed from the **sum** of token log-ratios
+  - `"geometric"`: the sequence IS ratio is computed from the **mean** of token log-ratios (equivalent to the geometric mean of probability ratios)
+- **Default**: `null`
+
+### infer_seq_mask_threshold_min
+- **Meaning**: Lower bound on the sequence-level IS ratio below which sequences are masked.
+- **Default**: `0.1`
+- **Typical value**: Usually `1/max`; when `"geometric"` is used, it is best to force it to `1/max`.
+
+### infer_seq_mask_threshold_max
+- **Meaning**: Upper bound on the sequence-level IS ratio above which sequences are masked.
+- **Default**: `10.0`
+- **Typical range**: `2.0 ~ 10.0` for `"sequence"`, which can be relaxed as responses get longer; `1.0001 ~ 1.001` for `"geometric"`.
+
+## Usage Recommendations
+
+1. In most cases `old_log_prob` is much smaller than `infer_log_prob`, so the lower thresholds are the ones that matter. Sequence-level IS or masking is generally not recommended.
diff --git a/examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml b/examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml
new file mode 100644
index 000000000..78fefee47
--- /dev/null
+++ b/examples/qwen2.5-infer_correction/agentic_webshop_infer_correction.yaml
@@ -0,0 +1,183 @@
+defaults:
+  - ../config/traj_envs@_here_
+  - ../config/deepspeed_zero@_here_
+  - ../config/deepspeed_zero2@_here_
+  - ../config/deepspeed_zero3@_here_
+  - ../config/deepspeed_zero3_cpuoffload@_here_
+
+hydra:
+  run:
+    dir: .
+  output_subdir: null
+
+exp_name: "agentic_pipeline_webshop_infer_correction"
+seed: 42
+logging_dir: ./output/logs
+output_dir: ./output
+render_save_dir: ./output/render
+system_envs:
+  USE_MODELSCOPE: '1'
+
+#track_with: wandb
+#tracker_kwargs:
+#  api_key:
+#  project: roll-agentic
+#  name: ${exp_name}_webshop
+#  notes: "agentic_pipeline"
+#  tags:
+#    - agentic
+#    - roll
+#    - baseline
+
+#track_with: swanlab
+#tracker_kwargs:
+#  login_kwargs:
+#    api_key: your_api_key
+#  project: roll-agentic
+#  logdir: debug
+#  experiment_name: ${exp_name}
+#  tags:
+#    - roll
+#    - agentic
+#    - debug
+
+track_with: tensorboard
+tracker_kwargs:
+  log_dir: /data/oss_bucket_0/yali/llm/tensorboard/roll_exp/agentic_webshop
+
+num_gpus_per_node: 8
+
+max_steps: 1024
+save_steps: 10000
+logging_steps: 1
+eval_steps: 10
+resume_from_checkpoint: false
+
+rollout_batch_size: 64
+val_batch_size: 64
+sequence_length: 8192
+
+reward_clip: 20
+advantage_clip: 0.2 # 0.1-0.3
+ppo_epochs: 1
+adv_estimator: "grpo"
+#pg_clip: 0.1
+max_grad_norm: 1.0
+#dual_clip_loss: True
+init_kl_coef: 0.0
+whiten_advantages: true
+entropy_loss_coef: 0
+
+pretrain: Qwen/Qwen2.5-7B-Instruct
+reward_pretrain: Qwen/Qwen2.5-7B-Instruct
+
+# infer correction
+
+infer_correction: true
+
+infer_is_mode: token # token sequence
+infer_is_threshold_min: 0.0
+infer_is_threshold_max: 2.0 # 1.5~5.0
+
+enable_token_reject: false
+infer_token_mask_threshold_min: 0.0
+infer_token_mask_threshold_max: 2.0 # 2~10
+
+enable_catastrophic_reject: false
+infer_catastrophic_threshold: 1e-4
+
+enable_seq_reject: None # None sequence geometric
+infer_seq_mask_threshold_min: 0.999
+infer_seq_mask_threshold_max: 1.001
+
+# enable_seq_reject: geometric
+# infer_seq_mask_threshold_min: 0.999
+# infer_seq_mask_threshold_max: 1.001
+
+actor_train:
+  model_args:
+    attn_implementation: fa2
+    disable_gradient_checkpointing: false
+    dtype: bf16
+    model_type: ~
+  training_args:
+    learning_rate: 1.0e-6
+    weight_decay: 0
+    per_device_train_batch_size: 1
+    gradient_accumulation_steps: 8
+    warmup_steps: 10
+  data_args:
+    template: qwen2_5
+  strategy_args:
+    strategy_name: megatron_train
+    strategy_config:
+      tensor_model_parallel_size: 1
+      context_parallel_size: 1
+      pipeline_model_parallel_size: 1
+      expert_model_parallel_size: 1
+      use_distributed_optimizer: true
+      recompute_granularity: full
+      max_grad_norm: ${max_grad_norm}
+  device_mapping: list(range(0,8))
+  infer_batch_size: 1
+
+actor_infer:
+  model_args:
+    disable_gradient_checkpointing: true
+    dtype: bf16
+  generating_args:
+    max_new_tokens: 1024 # single-turn response length
+    top_p: 0.99
+    top_k: 100
+    num_beams: 1
+    temperature: 0.99
+    num_return_sequences: 1
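+    # logprobs >= 1 asks the rollout backend to return per-token log-probs (stored as infer_logprobs and consumed when infer_correction is enabled)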
logprobs: 1 + data_args: + template: qwen2_5 + strategy_args: + strategy_name: sglang + strategy_config: + mem_fraction_static: 0.85 + load_format: auto + device_mapping: list(range(0,8)) + infer_batch_size: 1 + +reference: + model_args: + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: ~ + data_args: + template: qwen2_5 + strategy_args: + strategy_name: hf_infer + strategy_config: ~ + device_mapping: list(range(0,8)) + infer_batch_size: 1 + +reward_normalization: + grouping: traj_group_id # 可以tags(env_type)/traj_group_id(group)/batch(rollout_batch)... group_by计算reward/adv + method: mean_std # asym_clip / identity / mean_std + +train_env_manager: + format_penalty: -0.05 + num_env_groups: 8 + group_size: 8 + max_env_num_per_worker: 1 # The max_env_num_per_worker must be set to 1 to avoid conflicts with the webshop simple server. + tags: [WebShopEnv] + num_groups_partition: [8] # If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation + +val_env_manager: + num_env_groups: 64 + group_size: 1 # should be set to 1 because val temperature is set to 0 and same prompt leads to same output + max_env_num_per_worker: 1 # The max_env_num_per_worker must be set to 1 to avoid conflicts with the webshop simple server. + tags: [WebShopEnv] + num_groups_partition: [64] # TODO: If not set, all env names divide nums equally. Under the same group, the env config and env seed (prompt) are equal in each generation + +custom_envs: + WebShopEnv: + ${custom_env.WebShopEnv} \ No newline at end of file diff --git a/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml new file mode 100644 index 000000000..1bbe3c8e8 --- /dev/null +++ b/examples/qwen2.5-infer_correction/rlvr_infer_correction_config.yaml @@ -0,0 +1,265 @@ +hydra: + run: + dir: . 
+ output_subdir: null + +exp_name: "qwen2.5-7B-rlvr-infer-correction-config" +seed: 42 +logging_dir: ./output/logs +output_dir: ./output +system_envs: + USE_MODELSCOPE: '1' + + +checkpoint_config: + type: file_system + output_dir: /data/cpfs_0/rl_examples/models/${exp_name} + +track_with: tensorboard +tracker_kwargs: + log_dir: ./rl_examples/llm/tensorboard/roll_exp/rlvr + +num_gpus_per_node: 8 + +max_steps: 500 +save_steps: 100 +logging_steps: 1 +eval_steps: 10 +resume_from_checkpoint: false + + +rollout_batch_size: 64 # prompt +prompt_length: 2048 +response_length: 4096 + +num_return_sequences_in_group: 8 +ppo_epochs: 1 +adv_estimator: "reinforce" + +# clip +value_clip: 0.5 +reward_clip: 10 +advantage_clip: 2.0 +dual_clip_loss: true + +# normalize +norm_mean_type: ~ +norm_std_type: ~ + +# data mask +max_len_mask: true +difficulty_mask: true +difficulty_low_threshold: 0.1 +difficulty_high_threshold: 0.95 +error_max_len_clip: false + +# data weight +difficulty_loss_weight: false +length_loss_weight: false + +# reward +add_token_level_kl: false + +# advantage +whiten_advantages: true + +# dynamic sampling scheduler +# use_additional_prompts: true +# max_running_requests: 256 +# is_num_return_sequences_expand: false + +pretrain: Qwen/Qwen2.5-7B-Instruct +reward_pretrain: Qwen/Qwen2.5-7B-Instruct + + +# infer correction +infer_correction: true + +infer_is_mode: token +infer_is_threshold_min: 0.0 +infer_is_threshold_max: 2.0 # 1.5~5.0 + +enable_token_reject: false +infer_token_mask_threshold_min: 0.0 +infer_token_mask_threshold_max: 2.0 # 2~10 + +enable_catastrophic_reject: false +infer_catastrophic_threshold: 1e-4 + +enable_seq_reject: None + +# enable_seq_reject: sequence +# infer_seq_mask_threshold_min: 0.1 +# infer_seq_mask_threshold_max: 10 + +# enable_seq_reject: geometric +# infer_seq_mask_threshold_min: 0.999 +# infer_seq_mask_threshold_max: 1.001 + +validation: + data_args: + template: qwen2.5 + file_name: + - data/math_benchmarks.jsonl + generating_args: + max_new_tokens: ${response_length} + top_p: 0.6 + top_k: 50 + num_beams: 1 + temperature: 0.6 + num_return_sequences: 1 + + +actor_train: + model_args: + disable_gradient_checkpointing: false + dtype: bf16 + model_type: ~ + training_args: + learning_rate: 1.0e-6 + weight_decay: 0 + per_device_train_batch_size: 1 + gradient_accumulation_steps: 32 + warmup_steps: 20 + num_train_epochs: 50 + data_args: + template: qwen2.5 + file_name: + - data/code_KodCode_data.jsonl + - data/llm_judge_Multi-subject-RLVR_deal_new.jsonl + - data/math_deepmath_deal.jsonl + - data/general_ifeval_train_deal.jsonl + - data/general_CrossThink-QA_deal.jsonl + domain_interleave_probs: + math_rule: 0.4 + code_sandbox: 0.3 + llm_judge: 0.1 + crossthinkqa: 0.1 + ifeval: 0.1 + dataset_dir: data + messages: messages + interleave_probs: "1.0" + preprocessing_num_workers: 16 + strategy_args: + strategy_name: megatron_train + strategy_config: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + use_distributed_optimizer: true + recompute_granularity: full + device_mapping: list(range(0,16)) + infer_batch_size: 4 + +actor_infer: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + generating_args: + max_new_tokens: ${response_length} + top_p: 0.99 + top_k: 100 + num_beams: 1 + temperature: 0.99 + num_return_sequences: ${num_return_sequences_in_group} + logprobs: 1 + data_args: + template: qwen2.5 + strategy_args: + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.8 + block_size: 16 + 
max_model_len: 8000 + device_mapping: list(range(0,12)) + infer_batch_size: 1 + +reference: + model_args: + disable_gradient_checkpointing: true + dtype: bf16 + model_type: ~ + data_args: + template: qwen2.5 + strategy_args: + strategy_name: megatron_infer + strategy_config: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + expert_model_parallel_size: 1 + device_mapping: list(range(0,16)) + infer_batch_size: 4 + +rewards: + crossthinkqa: + worker_cls: roll.pipeline.rlvr.rewards.crossthinkqa_rule_reward_worker.CrossThinkQARuleRewardWorker + reward_type: soft + response_length_penalty_coef: 0.0 + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + tag_included: [crossthinkqa] + world_size: 8 + infer_batch_size: 4 + ifeval: + worker_cls: roll.pipeline.rlvr.rewards.ifeval_rule_reward_worker.GeneralRuleRewardWorker + reward_type: soft + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + tag_included: [ifeval] + world_size: 8 + infer_batch_size: 4 + math_rule: + worker_cls: roll.pipeline.rlvr.rewards.math_rule_reward_worker.MathRuleRewardWorker + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + tag_included: [deepmath_103k, aime] + world_size: 8 + infer_batch_size: 1 + code_sandbox: + use_local: true + worker_cls: roll.pipeline.rlvr.rewards.code_sandbox_reward_worker.CodeSandboxRewardWorker + tag_included: [KodCode] + model_args: + model_name_or_path: ${reward_pretrain} + data_args: + template: qwen2.5 + world_size: 8 + infer_batch_size: 1 + llm_judge: + # NOTE: llm as judge 也需要gpu, 不能和actor infer共享gpu + worker_cls: roll.pipeline.rlvr.rewards.llm_judge_reward_worker.LLMJudgeRewardWorker + judge_prompt: Qwen2.5-7B-Instruct-RLVR-prompt + judge_model_type: inference + tag_included: [RLVR] + model_args: + model_name_or_path: virtuoussy/Qwen2.5-7B-Instruct-RLVR + attn_implementation: fa2 + disable_gradient_checkpointing: true + dtype: bf16 + model_type: trl + generating_args: + max_new_tokens: 100 + top_p: 0.8 + top_k: 50 + num_beams: 1 + temperature: 0.8 + num_return_sequences: 1 + data_args: + template: qwen2.5 + strategy_args: + # strategy_name: hf_infer + # strategy_config: null + strategy_name: vllm + strategy_config: + gpu_memory_utilization: 0.8 + block_size: 16 + max_model_len: 8000 + load_format: auto + device_mapping: list(range(12,16)) + infer_batch_size: 4 \ No newline at end of file diff --git a/examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh b/examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh new file mode 100644 index 000000000..ff1a59402 --- /dev/null +++ b/examples/qwen2.5-infer_correction/run_agentic_pipeline_webshop.sh @@ -0,0 +1,10 @@ +#!/bin/bash +# Run `git submodule update --init --recursive` to init submodules before run this script. 
+set +x + + +pip install -r third_party/webshop-minimal/requirements.txt --trusted-host mirrors.aliyun.com --index-url https://mirrors.aliyun.com/pypi/simple/ +python -m spacy download en_core_web_sm + +CONFIG_PATH=$(basename $(dirname $0)) +python examples/start_agentic_pipeline.py --config_path $CONFIG_PATH --config_name agentic_webshop_infer_correction diff --git a/examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh b/examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh new file mode 100644 index 000000000..6874f4f5e --- /dev/null +++ b/examples/qwen2.5-infer_correction/run_rlvr_pipeline.sh @@ -0,0 +1,5 @@ +#!/bin/bash +set +x + +CONFIG_PATH=$(basename $(dirname $0)) +python examples/start_rlvr_pipeline.py --config_path $CONFIG_PATH --config_name rlvr_infer_correction_config \ No newline at end of file diff --git a/roll/configs/base_config.py b/roll/configs/base_config.py index aa1da4c27..a6ea37f09 100644 --- a/roll/configs/base_config.py +++ b/roll/configs/base_config.py @@ -395,12 +395,68 @@ class PPOConfig(BaseConfig): field(default="seq-mean-token-mean", metadata={"help": "Loss aggregation mode"}) ) dual_clip_loss: bool = field(default=False, metadata={"help": "Use dual clip loss"}) + enable_reference: bool = field( default=False, metadata={"help": "Whether to enable reference cluster for computing ref_log_probs."} ) enable_old_logprobs_recompute: bool = field(default=False, metadata={"help": "Enable old_logprobs computation optimization for disable caching"}) force_disable_old_logprobs_recompute: bool = field(default=False, metadata={"help": "Force disable old_logprobs computation optimization for disable caching, priority is higher than enable_old_logprobs_recompute"}) + # trainer&rollout mismatch + infer_correction: bool = field( + default=False, + metadata={"help": "Whether to apply importance sampling correction during inference."} + ) + infer_is_mode: Literal["token", "sequence", "None"] = field( + default="None", + metadata={"help": "IS weighting mode: 'token' (per-token ratio), 'sequence' (per-sequence ratio), 'None' (no IS weighting)."} + ) + # Clipping thresholds (used in IS weighting) + infer_is_threshold_min: float = field( + default=0.0, + metadata={"help": "Minimum threshold for IS weight clipping. 
Recommended 0.0 for unbiased estimation."} + ) + infer_is_threshold_max: float = field( + default=2.0, + metadata={"help": "Maximum threshold for IS weight clipping."} + ) + # Token-level rejection + enable_token_reject: bool = field( + default=False, + metadata={"help": "Enable token-level rejection based on IS ratio thresholds."} + ) + infer_token_mask_threshold_min: float = field( + default=0.0, + metadata={"help": "Minimum IS ratio threshold for token rejection."} + ) + infer_token_mask_threshold_max: float = field( + default=2.0, + metadata={"help": "Maximum IS ratio threshold for token rejection."} + ) + # Catastrophic rejection (reject entire sequence if any token ratio is too small) + enable_catastrophic_reject: bool = field( + default=False, + metadata={"help": "Enable catastrophic rejection: reject entire sequence if any valid token has IS ratio below threshold."} + ) + infer_catastrophic_threshold: float = field( + default=1e-4, + metadata={"help": "Threshold below which a token triggers catastrophic rejection of its sequence."} + ) + # Sequence-level rejection + enable_seq_reject: Optional[Literal["sequence", "geometric",'None']] = field( + default=None, + metadata={"help": "Enable sequence-level rejection: 'sequence' uses sum of log-ratios, 'geometric' uses mean. None disables."} + ) + infer_seq_mask_threshold_min: float = field( + default=0.1, + metadata={"help": "Minimum IS ratio threshold for sequence rejection."} + ) + infer_seq_mask_threshold_max: float = field( + default=10.0, + metadata={"help": "Maximum IS ratio threshold for sequence rejection (typically larger than token-level)."} + ) + + def __post_init__(self): super().__post_init__() diff --git a/roll/configs/generating_args.py b/roll/configs/generating_args.py index 68cf88d17..5d1f94f95 100644 --- a/roll/configs/generating_args.py +++ b/roll/configs/generating_args.py @@ -58,6 +58,10 @@ class GeneratingArguments: default=None, metadata={"help": "Whether to include the stop strings in output text."}, ) + logprobs: Optional[int] = field( + default=0, + metadata={"help": "Whether return infer log-prob."}, + ) def to_dict(self) -> Dict[str, Any]: args = asdict(self) diff --git a/roll/pipeline/agentic/agentic_actor_worker.py b/roll/pipeline/agentic/agentic_actor_worker.py index 75510c675..e97fef5ba 100644 --- a/roll/pipeline/agentic/agentic_actor_worker.py +++ b/roll/pipeline/agentic/agentic_actor_worker.py @@ -4,6 +4,7 @@ from roll.distributed.scheduler.protocol import DataProto from roll.pipeline.base_worker import ActorWorker as BaseActorWorker from roll.utils.functionals import masked_mean, agg_loss, compute_approx_kl +from roll.utils.infer_correction import InferCorrectionHandler class ActorWorker(BaseActorWorker): @@ -14,6 +15,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): output_tensor: torch.Tensor, model.forward()的输出Tensor """ response_mask = data.batch["response_mask"][:, 1:].long() + final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] advantages = data.batch["advantages"] @@ -29,8 +31,6 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): else: ratio = (log_probs - old_log_probs).exp() - train_infer_ratio = (log_probs - infer_log_probs).exp() - train_infer_diff = log_probs.exp() - infer_log_probs.exp() pg_clip_low = self.pipeline_config.pg_clip_low if self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip pg_clip_high = self.pipeline_config.pg_clip_high if 
self.pipeline_config.use_pg_clip_range else self.pipeline_config.pg_clip @@ -41,11 +41,23 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): dual_clip_loss = -torch.max(-pg_loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) pg_loss = torch.where(advantages < 0, dual_clip_loss, pg_loss) - pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + infer_stats = {} + if infer_log_probs is not None and self.pipeline_config.infer_correction: + correction_handler = InferCorrectionHandler(self.pipeline_config) + pg_loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=pg_loss + ) + # 更新最终掩码 + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() + + pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) - kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=response_mask, + kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=final_response_mask, kl_penalty="k3") - kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) approxkl = compute_approx_kl( log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="mse" @@ -70,11 +82,6 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): ) total_loss = total_loss - entropy_loss * self.pipeline_config.entropy_loss_coef - train_infer_prob_metric = { - "actor/train_infer_ratio_mean": masked_mean(train_infer_ratio, response_mask, dim=-1).mean().detach().item(), - "actor/train_infer_diff_mean": masked_mean(train_infer_diff, response_mask, dim=-1).mean().detach().item(), - } - pg_metrics = { "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(), "actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(), @@ -91,8 +98,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(), "actor/policykl": agg_loss(loss_mat=policykl, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(), - **train_infer_prob_metric } - + pg_metrics.update(infer_stats) + return total_loss, pg_metrics diff --git a/roll/pipeline/agentic/env_manager/step_env_manager.py b/roll/pipeline/agentic/env_manager/step_env_manager.py index 4348605a3..3eb8cefdf 100644 --- a/roll/pipeline/agentic/env_manager/step_env_manager.py +++ b/roll/pipeline/agentic/env_manager/step_env_manager.py @@ -79,7 +79,7 @@ def formulate_rollouts(self, rollout_cache: RolloutCache): input_ids =torch.tensor(token_ids, dtype=torch.long).unsqueeze(0) attention_mask = torch.tensor([1] * len(token_ids), dtype=torch.long).unsqueeze(0) response_mask = torch.tensor(response_masks, dtype=torch.bool).unsqueeze(0) - infer_logprobs = [] + infer_logprobs=[] if "infer_logprobs" in history: infer_logprobs = [0] * len(history["prompt_ids"]) + history["infer_logprobs"] diff --git a/roll/pipeline/agentic/env_manager/traj_env_manager.py b/roll/pipeline/agentic/env_manager/traj_env_manager.py index 88ab15d91..ddc041ae0 100644 --- a/roll/pipeline/agentic/env_manager/traj_env_manager.py +++ 
b/roll/pipeline/agentic/env_manager/traj_env_manager.py @@ -11,7 +11,7 @@ from omegaconf import DictConfig from tensordict import TensorDict from transformers import PreTrainedTokenizer - +import json from roll.pipeline.agentic.llm_proxy import create_llm_proxy, BaseLLMProxy from roll.pipeline.agentic.env_manager.base_env_manager import RolloutCache, BaseEnvManager from roll.utils.env_action_limiter import get_global_limiter diff --git a/roll/pipeline/base_worker.py b/roll/pipeline/base_worker.py index d5c84c120..e7cb65389 100644 --- a/roll/pipeline/base_worker.py +++ b/roll/pipeline/base_worker.py @@ -2,7 +2,7 @@ import threading import time from typing import Union, Optional, Dict - +import numpy import ray import torch from codetiming import Timer @@ -25,9 +25,11 @@ postprocess_generate, GenerateRequestType, agg_loss, + masked_sum, ) from roll.utils.offload_nccl import reload_process_groups from roll.utils.offload_states import OffloadStateType +from roll.utils.infer_correction import InferCorrectionHandler from roll.utils.dynamic_batching import make_mini_batch_iter_for_dynamic_batching from roll.platforms import current_platform @@ -307,9 +309,11 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): """ response_mask = data.batch["response_mask"][:, 1:].long() + final_response_mask = data.batch.get("final_response_mask", response_mask) + ref_log_probs = data.batch["ref_log_probs"] advantages = data.batch["advantages"] - + infer_log_probs = data.batch.get("infer_logprobs", None) log_probs = self.strategy.op_compute_log_probs( logits=output_tensor, input_ids=data.batch["input_ids"], attention_mask=data.batch["response_mask"] ) @@ -326,11 +330,24 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): dual_clip_loss = -torch.max(-pg_loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) pg_loss = torch.where(advantages < 0, dual_clip_loss, pg_loss) - pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + infer_stats = {} + if infer_log_probs is not None and self.pipeline_config.infer_correction: + correction_handler = InferCorrectionHandler(self.pipeline_config) + + pg_loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=pg_loss + ) + # 更新最终掩码 + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() + + pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) - kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=response_mask, + kl_loss = compute_approx_kl(log_probs=log_probs, log_probs_base=ref_log_probs, action_mask=final_response_mask, kl_penalty="k3") - kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) + kl_loss = agg_loss(loss_mat=kl_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode) approxkl = compute_approx_kl( log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="mse" @@ -372,6 +389,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): "actor/policykl": agg_loss(loss_mat=policykl, loss_mask=response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode).detach().item(), } + pg_metrics.update(infer_stats) return total_loss, pg_metrics diff --git a/roll/pipeline/rlvr/actor_pg_worker.py 
b/roll/pipeline/rlvr/actor_pg_worker.py index 477438595..e83ed0a9b 100644 --- a/roll/pipeline/rlvr/actor_pg_worker.py +++ b/roll/pipeline/rlvr/actor_pg_worker.py @@ -4,6 +4,7 @@ from roll.distributed.scheduler.protocol import DataProto from roll.utils.functionals import masked_mean, agg_loss, compute_approx_kl from roll.pipeline.rlvr.actor_worker import ActorWorker +from roll.utils.infer_correction import InferCorrectionHandler class ActorPGWorker(ActorWorker): @@ -33,6 +34,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] + old_log_probs = data.batch["old_log_probs"] + infer_log_probs = data.batch.get("infer_logprobs", None) advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( @@ -77,6 +80,18 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): else: raise ValueError(f"Unsupported pg_variant: {pg_variant}") + infer_stats = {} + if infer_log_probs is not None and self.pipeline_config.infer_correction: + correction_handler = InferCorrectionHandler(self.pipeline_config) + pg_loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=pg_loss + ) + # 更新最终掩码 + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() + weighted_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights) original_pg_loss = agg_loss(loss_mat=pg_loss, loss_mask=final_response_mask, @@ -127,6 +142,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): # 使用缓存的指标 pg_metrics = self._get_pg_metrics(data) + + pg_metrics.update(infer_stats) return total_loss, pg_metrics diff --git a/roll/pipeline/rlvr/actor_worker.py b/roll/pipeline/rlvr/actor_worker.py index 19d0c66de..da34d3a38 100644 --- a/roll/pipeline/rlvr/actor_worker.py +++ b/roll/pipeline/rlvr/actor_worker.py @@ -4,6 +4,7 @@ from roll.distributed.scheduler.protocol import DataProto from roll.pipeline.base_worker import ActorWorker as BaseActorWorker from roll.utils.functionals import masked_mean, agg_loss, compute_approx_kl +from roll.utils.infer_correction import InferCorrectionHandler class ActorWorker(BaseActorWorker): @@ -17,6 +18,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): response_mask = data.batch["response_mask"][:, 1:].long() final_response_mask = data.batch.get("final_response_mask", response_mask) ref_log_probs = data.batch["ref_log_probs"] + advantages = data.batch["advantages"] log_probs = self.strategy.op_compute_log_probs( @@ -53,34 +55,6 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): log_probs=log_probs, log_probs_base=old_log_probs, action_mask=response_mask, kl_penalty="kl" ) - train_infer_ratio = (old_log_probs - infer_log_probs).exp() - train_infer_diff = old_log_probs.exp() - infer_log_probs.exp() - train_infer_ratio_seq = masked_mean(old_log_probs - infer_log_probs, response_mask, dim=-1).exp().unsqueeze(-1).expand_as(train_infer_ratio) - train_infer_diff_seq = masked_mean(old_log_probs.exp() - infer_log_probs.exp(), response_mask, dim=-1).unsqueeze(-1).expand_as(train_infer_diff) - - train_infer_ratio_mask_mean = 1.0 - train_infer_diff_mask_mean = 1.0 - train_infer_ratio_seq_mask_mean = 1.0 - train_infer_diff_seq_mask_mean = 1.0 - - if 
self.pipeline_config.train_infer_ratio_mask: - train_infer_ratio_mask = (train_infer_ratio <= self.pipeline_config.train_infer_ratio_threshold_high).float() * (train_infer_ratio >= self.pipeline_config.train_infer_ratio_threshold_low).float() - train_infer_ratio_mask_mean = masked_mean(train_infer_ratio_mask, final_response_mask, dim=-1).mean().detach().item() - final_response_mask = final_response_mask * train_infer_ratio_mask - if self.pipeline_config.train_infer_diff_mask: - train_infer_diff_mask = (train_infer_diff <= self.pipeline_config.train_infer_diff_threshold_high).float() * (train_infer_diff >= self.pipeline_config.train_infer_diff_threshold_low).float() - train_infer_diff_mask_mean = masked_mean(train_infer_diff_mask, final_response_mask, dim=-1).mean().detach().item() - final_response_mask = final_response_mask * train_infer_diff_mask - - if self.pipeline_config.train_infer_ratio_seq_mask: - train_infer_ratio_seq_mask = (train_infer_ratio_seq <= self.pipeline_config.train_infer_ratio_seq_threshold_high).float() * (train_infer_ratio_seq >= self.pipeline_config.train_infer_ratio_seq_threshold_low).float() - train_infer_ratio_seq_mask_mean = masked_mean(train_infer_ratio_seq_mask, final_response_mask, dim=-1).mean().detach().item() - final_response_mask = final_response_mask * train_infer_ratio_seq_mask - if self.pipeline_config.train_infer_diff_seq_mask: - train_infer_diff_seq_mask = (train_infer_diff_seq <= self.pipeline_config.train_infer_diff_seq_threshold_high).float() * (train_infer_diff_seq >= self.pipeline_config.train_infer_diff_seq_threshold_low).float() - train_infer_diff_seq_mask_mean = masked_mean(train_infer_diff_seq_mask, final_response_mask, dim=-1).mean().detach().item() - final_response_mask = final_response_mask * train_infer_diff_seq_mask - if self.pipeline_config.importance_sampling == "token": ratio = (log_probs - old_log_probs).exp() elif self.pipeline_config.importance_sampling == "seq": @@ -98,11 +72,21 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): if self.pipeline_config.dual_clip_loss: dual_clip_loss = -torch.max(-loss, (1 + self.pipeline_config.pg_clip * 2) * advantages) loss = torch.where(advantages < 0, dual_clip_loss, loss) + + infer_stats = {} + if infer_log_probs is not None and self.pipeline_config.infer_correction: + correction_handler = InferCorrectionHandler(self.pipeline_config) + loss, infer_response_mask, infer_stats = correction_handler( + old_log_probs=old_log_probs, + infer_log_probs=infer_log_probs, + response_mask=response_mask, + pg_loss=loss + ) + # 更新最终掩码 + final_response_mask = (final_response_mask.bool() & infer_response_mask).long() - if self.pipeline_config.use_rollout_importance_sampling_ratio: - rollout_importance_sampling_clip = (train_infer_ratio > self.pipeline_config.rollout_importance_sampling_ratio_upper_bound).float() - loss = train_infer_ratio.clamp(0, self.pipeline_config.rollout_importance_sampling_ratio_upper_bound) * loss + weighted_pg_loss = agg_loss(loss_mat=loss, loss_mask=final_response_mask, loss_agg_mode=self.pipeline_config.loss_agg_mode, weights=sample_weights, loss_scale=loss_scale) @@ -150,15 +134,7 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): total_loss = total_loss + topr_neg_loss * self.pipeline_config.use_topr_neg_loss_coef metrics['actor/topr_neg_loss'] = topr_neg_loss.detach().item() - train_infer_prob_metric = { - "actor/train_infer_ratio_mean": masked_mean(train_infer_ratio, response_mask, dim=-1).mean().detach().item(), - 
"actor/train_infer_diff_mean": masked_mean(train_infer_diff, response_mask, dim=-1).mean().detach().item(), - "actor/train_infer_ratio_mask_mean": train_infer_ratio_mask_mean, - "actor/train_infer_diff_mask_mean": train_infer_diff_mask_mean, - "actor/train_infer_ratio_seq_mask_mean": train_infer_ratio_seq_mask_mean, - "actor/train_infer_diff_seq_mask_mean": train_infer_diff_seq_mask_mean, - } - + loss_metric = { "actor/ppo_ratio_high_clipfrac": clipped_high.mean().detach().item(), "actor/ppo_ratio_low_clipfrac": clipped_low.mean().detach().item(), @@ -170,9 +146,6 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): loss_agg_mode=self.pipeline_config.loss_agg_mode, loss_scale=loss_scale).detach().item(), } - if self.pipeline_config.use_rollout_importance_sampling_ratio: - loss_metric["actor/rollout_importance_sampling_clip"] = rollout_importance_sampling_clip.mean().detach().item() - pg_metrics = { "actor/pg_loss": original_pg_loss.detach().item(), "actor/weighted_pg_loss": weighted_pg_loss.detach().item(), @@ -190,9 +163,8 @@ def loss_func(self, data: DataProto, output_tensor: torch.Tensor): "actor/sample_weights_max": sample_weights.max().detach().item(), **metrics, **loss_metric, - **train_infer_prob_metric } - + pg_metrics.update(infer_stats) return total_loss, pg_metrics def compute_sample_weights(self, data: DataProto, response_mask: torch.Tensor): diff --git a/roll/pipeline/rlvr/rlvr_config.py b/roll/pipeline/rlvr/rlvr_config.py index ba51d62ef..d053e2e9a 100644 --- a/roll/pipeline/rlvr/rlvr_config.py +++ b/roll/pipeline/rlvr/rlvr_config.py @@ -149,22 +149,6 @@ class RLVRConfig(PPOConfig): importance_sampling: Literal["token", "seq"] = ( field(default="token", metadata={"help": "policy importance sampling"}) ) - use_rollout_importance_sampling_ratio: bool = field(default=False, metadata={"help": "apply train/infer ratio as token-level loss weight"}) - rollout_importance_sampling_ratio_upper_bound: float = field(default=1.2) - - train_infer_ratio_mask: bool = field(default=False, metadata={"help": "apply train/infer ratio as token-level response mask"}) - train_infer_ratio_threshold_low: float = field(default=0.8) - train_infer_ratio_threshold_high: float = field(default=1.2) - train_infer_diff_mask: bool = field(default=False, metadata={"help": "apply train-infer diff as token-level response mask"}) - train_infer_diff_threshold_low: float = field(default=-0.2) - train_infer_diff_threshold_high: float = field(default=0.2) - - train_infer_ratio_seq_mask: bool = field(default=False, metadata={"help": "apply train/infer ratio as sequence-level response mask"}) - train_infer_ratio_seq_threshold_low: float = field(default=0.8) - train_infer_ratio_seq_threshold_high: float = field(default=1.2) - train_infer_diff_seq_mask: bool = field(default=False, metadata={"help": "apply train-infer diff as sequence-level response mask"}) - train_infer_diff_seq_threshold_low: float = field(default=-0.2) - train_infer_diff_seq_threshold_high: float = field(default=0.2) val_greedy: bool = field(default=False, metadata={"help": "Use greedy for validation"}) val_n_sample: int = field(default=1, metadata={"help": "Number of samples for validation"}) diff --git a/roll/utils/infer_correction.py b/roll/utils/infer_correction.py new file mode 100644 index 000000000..215203081 --- /dev/null +++ b/roll/utils/infer_correction.py @@ -0,0 +1,380 @@ +from typing import Literal, Optional, Tuple, Dict, Any +import torch + +class StatsCollector: + """统一收集诊断指标的类""" + def __init__(self, prefix: str = 
"infer_correction"): + self.prefix = prefix + self.stats: Dict[str, Any] = {} + self.tensor_stats: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {} + + def add(self, name: str, value: Any): + """添加标量指标""" + self.stats[f"{self.prefix}/{name}"] = value.item() if torch.is_tensor(value) else value + + def add_tensor_stat(self, name: str, tensor: torch.Tensor, mask: torch.Tensor): + """添加张量统计指标(延迟计算)""" + self.tensor_stats[name] = (tensor, mask) + + def compute_tensor_stats(self): + """严格遵循原始代码的数据移动策略""" + for name, (tensor, mask) in self.tensor_stats.items(): + # 1. 确保在同一设备上 + if tensor.device != mask.device: + mask = mask.to(tensor.device) + + # 2. 直接在原始代码风格中计算:先筛选,再移动到CPU + mask=mask.bool() + valid = tensor[mask] + + # 3. 严格按照原始代码逻辑处理 + if valid.numel() > 0: + # 关键:先detach()再item(),确保在CPU上计算 + valid_cpu = valid.detach().cpu() + self.add(f"{name}_mean", valid_cpu.mean().item()) + self.add(f"{name}_std", valid_cpu.std(unbiased=False).item() if valid_cpu.numel() > 1 else 0.0) + self.add(f"{name}_min", valid_cpu.min().item()) + self.add(f"{name}_max", valid_cpu.max().item()) + else: + self.add(f"{name}_mean", 0.0) + self.add(f"{name}_std", 0.0) + self.add(f"{name}_min", 0.0) + self.add(f"{name}_max", 0.0) + + self.tensor_stats.clear() + + def get_metrics(self) -> Dict[str, float]: + """获取所有指标""" + return self.stats.copy() + +class InferCorrectionHandler: + """处理重要性采样校正和样本拒绝的核心类""" + def __init__(self, pipeline_config: "PPOConfig"): + self.pipeline_config = pipeline_config + self.stats = StatsCollector() + + def __call__( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor, + pg_loss: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, float]]: + """ + 主入口:执行重要性采样校正和样本拒绝 + + Args: + old_log_probs: 历史策略的log概率 [B, T] + infer_log_probs: 生成时策略的log概率 [B, T] + response_mask: 有效token掩码 [B, T] + pg_loss: 原始策略梯度损失 [B, T] + + Returns: + weighted_loss: 重加权后的损失 + final_mask: 最终保留的token掩码 + metrics: 诊断指标字典 + """ + # 1. 对齐形状 + infer_log_probs = self._align_shapes(old_log_probs, infer_log_probs, response_mask) + + # 2. 计算IS权重 + ratio, raw_is_weight, is_weight = self._compute_is_weights(old_log_probs, infer_log_probs, response_mask) + + # 3. 收集基础统计 + self._collect_base_stats(ratio, response_mask) + + # 4. 应用拒绝策略 + keep_mask = response_mask.clone() + keep_mask = self._apply_token_rejection(ratio, keep_mask) + keep_mask = self._apply_catastrophic_rejection(ratio, keep_mask, response_mask) + keep_mask = self._apply_sequence_rejection(ratio, keep_mask, response_mask) + + # 5. 计算拒绝统计 + self._collect_rejection_stats(ratio, raw_is_weight, keep_mask, response_mask) + + # 6. 重加权损失 + weighted_loss = pg_loss * is_weight + + # 7. 计算KL指标 + self._compute_kl_metrics(old_log_probs, infer_log_probs, keep_mask, response_mask) + + # 8. 
批量计算张量统计 + self.stats.compute_tensor_stats() + + return weighted_loss, keep_mask, self.stats.get_metrics() + + def _align_shapes( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor + ) -> torch.Tensor: + """对齐log概率张量形状""" + if infer_log_probs.shape[1] == old_log_probs.shape[1] + 1: + infer_log_probs = infer_log_probs[:, 1:] + + assert old_log_probs.shape == infer_log_probs.shape == response_mask.shape, ( + f"Shape mismatch: old_log_probs {old_log_probs.shape}, " + f"infer_log_probs {infer_log_probs.shape}, " + f"response_mask {response_mask.shape}" + ) + return infer_log_probs + + def _compute_is_weights( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + response_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + 计算重要性采样权重 + + Returns: + ratio: 原始重要性比率 [B, T] + raw_is_weight: 未裁剪的IS权重 [B, T] + is_weight: 裁剪后的IS权重 [B, T] + """ + log_ratio = old_log_probs - infer_log_probs + ratio = torch.exp(log_ratio) + + if self.pipeline_config.infer_is_mode == "token": + raw_is_weight = ratio + elif self.pipeline_config.infer_is_mode == "sequence": + # 序列级IS:使用序列总log-ratio + log_ratio_sum = self._masked_sum(log_ratio, response_mask, dim=-1).unsqueeze(-1) + seq_ratio = torch.exp(log_ratio_sum) + raw_is_weight = seq_ratio.expand_as(ratio) + # 收集序列级统计 + self.stats.add_tensor_stat("seq_ratio", seq_ratio.squeeze(-1), torch.ones_like(seq_ratio.squeeze(-1), dtype=torch.bool)) + else: # "None" or any other value + raw_is_weight = torch.ones_like(ratio) + + # 裁剪IS权重 + is_weight = raw_is_weight.clamp( + min=self.pipeline_config.infer_is_threshold_min, + max=self.pipeline_config.infer_is_threshold_max + ).detach() + + return ratio, raw_is_weight, is_weight + + def _collect_base_stats(self, ratio: torch.Tensor, response_mask: torch.Tensor): + """收集基础统计指标""" + self.stats.add_tensor_stat("token_ratio", ratio, response_mask) + + if self.pipeline_config.infer_is_mode in ("token", "sequence"): + # 1. 裁剪比例统计(现有代码) + clipped_low = ratio <= self.pipeline_config.infer_is_threshold_min + clipped_high = ratio >= self.pipeline_config.infer_is_threshold_max + clipped = clipped_low | clipped_high + self.stats.add("token_clip_low_frac", self._agg_loss(clipped_low.float(), response_mask)) + self.stats.add("token_clip_high_frac", self._agg_loss(clipped_high.float(), response_mask)) + self.stats.add("token_clip_frac", self._agg_loss(clipped.float(), response_mask)) + + # 2. 
添加缺失的:裁剪后权重的分布统计 + if self.pipeline_config.infer_is_mode == "token": + # 重新计算裁剪后的权重 + is_weight = ratio.clamp( + min=self.pipeline_config.infer_is_threshold_min, + max=self.pipeline_config.infer_is_threshold_max + ) + # 添加缺失的统计 + self.stats.add_tensor_stat("token_is_weight", is_weight, response_mask) + + elif self.pipeline_config.infer_is_mode == "sequence": + # 序列级IS权重已在_compute_is_weights中添加 + pass + + def _apply_token_rejection( + self, + ratio: torch.Tensor, + keep_mask: torch.Tensor + ) -> torch.Tensor: + """应用token级拒绝策略""" + if not self.pipeline_config.enable_token_reject: + return keep_mask + + ratio_too_high = ratio > self.pipeline_config.infer_token_mask_threshold_max + ratio_too_low = ratio < self.pipeline_config.infer_token_mask_threshold_min + token_reject = ratio_too_high | ratio_too_low + + # 更新掩码:丢弃被拒绝的token + new_keep_mask = keep_mask & (~token_reject) + + # 收集统计 + self.stats.add("token_reject_low_frac", self._agg_loss(ratio_too_low.float(), keep_mask)) + self.stats.add("token_reject_high_frac", self._agg_loss(ratio_too_high.float(), keep_mask)) + + return new_keep_mask + + def _apply_catastrophic_rejection( + self, + ratio: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ) -> torch.Tensor: + """应用灾难性拒绝策略""" + if not self.pipeline_config.enable_catastrophic_reject: + return keep_mask + + # 识别灾难性token + catastrophic = (ratio < self.pipeline_config.infer_catastrophic_threshold) & response_mask + + # 检查哪些序列包含灾难性token + seq_has_catastrophic = catastrophic.any(dim=-1, keepdim=True) + + # 更新掩码:丢弃包含灾难性token的整个序列 + new_keep_mask = keep_mask & (~seq_has_catastrophic) + + # 收集统计 + catastrophic_token_frac = self._agg_loss(catastrophic.float(), response_mask) + self.stats.add("catastrophic_token_frac", catastrophic_token_frac) + + # 计算包含灾难性token的序列比例 + seq_has_valid = response_mask.any(dim=-1) + seq_has_catastrophic_flat = catastrophic.any(dim=-1) & seq_has_valid + catastrophic_seq_frac = ( + seq_has_catastrophic_flat.sum().float() / seq_has_valid.sum().float() + if seq_has_valid.sum() > 0 else 0.0 + ) + self.stats.add("catastrophic_seq_frac", catastrophic_seq_frac) + + return new_keep_mask + + def _apply_sequence_rejection( + self, + ratio: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ) -> torch.Tensor: + """应用序列级拒绝策略""" + if self.pipeline_config.enable_seq_reject in (None, "None", "none"): + return keep_mask + + # 计算序列级比率 + if self.pipeline_config.enable_seq_reject == "sequence": + log_ratio_agg = self._masked_sum(torch.log(ratio), response_mask, dim=-1) + elif self.pipeline_config.enable_seq_reject == "geometric": + log_ratio_agg = self._masked_mean(torch.log(ratio), response_mask, dim=-1) + else: + return keep_mask + + seq_ratio = torch.exp(log_ratio_agg) + + # 识别要拒绝的序列 + seq_too_high = seq_ratio > self.pipeline_config.infer_seq_mask_threshold_max + seq_too_low = seq_ratio < self.pipeline_config.infer_seq_mask_threshold_min + seq_reject = (seq_too_high | seq_too_low).unsqueeze(-1) + + # 更新掩码 + new_keep_mask = keep_mask & (~seq_reject) + + # 收集统计 + seq_has_valid = response_mask.any(dim=-1) + total_valid_seqs = seq_has_valid.sum().item() + + seq_reject_low = seq_too_low & seq_has_valid + seq_reject_high = seq_too_high & seq_has_valid + + seq_reject_low_frac = seq_reject_low.sum().item() / total_valid_seqs if total_valid_seqs > 0 else 0.0 + seq_reject_high_frac = seq_reject_high.sum().item() / total_valid_seqs if total_valid_seqs > 0 else 0.0 + + self.stats.add("seq_reject_low_frac", seq_reject_low_frac) + 
self.stats.add("seq_reject_high_frac", seq_reject_high_frac) + + return new_keep_mask + + def _collect_rejection_stats( + self, + ratio: torch.Tensor, + raw_is_weight: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ): + """收集拒绝相关的统计指标""" + + # 计算被拒绝的token + rejected_mask = response_mask & (~keep_mask) + self.stats.add("reject_frac", self._agg_loss(rejected_mask.float(), response_mask)) + + + # 仅在序列拒绝启用时计算序列级拒绝率 + if self.pipeline_config.enable_seq_reject not in (None, "None", "none"): + seq_has_valid = response_mask.any(dim=-1) + seq_completely_rejected = (~keep_mask).all(dim=-1) & seq_has_valid + total_valid_seqs = seq_has_valid.sum().item() + rejected_seqs = seq_completely_rejected.sum().item() + seq_reject_frac = rejected_seqs / total_valid_seqs if total_valid_seqs > 0 else 0.0 + self.stats.add("seq_reject_frac", seq_reject_frac) + else: + # 未启用时显式设为0.0 + self.stats.add("seq_reject_frac", 0.0) + + + if self.pipeline_config.infer_is_mode in ("token", "sequence"): + # 使用已计算的rejected_mask + clipped_mask = ((raw_is_weight <= self.pipeline_config.infer_is_threshold_min) | + (raw_is_weight >= self.pipeline_config.infer_is_threshold_max)) & response_mask + + clip_and_reject_frac = self._agg_loss((clipped_mask & rejected_mask).float(), response_mask) + clip_or_reject_frac = self._agg_loss((clipped_mask | rejected_mask).float(), response_mask) + + self.stats.add("token_clip_and_reject_frac", clip_and_reject_frac) + self.stats.add("token_clip_or_reject_frac", clip_or_reject_frac) + else: + # 关键:为未启用IS的情况提供默认值 + self.stats.add("token_clip_and_reject_frac", 0.0) + self.stats.add("token_clip_or_reject_frac", 0.0) + + def _compute_kl_metrics( + self, + old_log_probs: torch.Tensor, + infer_log_probs: torch.Tensor, + keep_mask: torch.Tensor, + response_mask: torch.Tensor + ): + """计算KL散度指标""" + # 原始KL(所有有效token) + inferkl_orig = self._compute_approx_kl(infer_log_probs, old_log_probs, response_mask, kl_penalty="kl") + inferkl_orig_agg = self._agg_loss(inferkl_orig, response_mask) + self.stats.add("inferkl", inferkl_orig_agg) + + # 拒绝后KL(仅保留的token) + inferkl_final = self._compute_approx_kl(infer_log_probs, old_log_probs, keep_mask, kl_penalty="kl") + inferkl_final_agg = self._agg_loss(inferkl_final, keep_mask) + self.stats.add("inferkl_reject", inferkl_final_agg) + + # --- 辅助方法(使用已有工具函数)--- + def _compute_approx_kl( + self, + log_probs: torch.Tensor, + log_probs_base: torch.Tensor, + action_mask: torch.Tensor, + kl_penalty: str = "kl" + ) -> torch.Tensor: + """使用已有的compute_approx_kl函数计算近似KL散度""" + from roll.utils.functionals import compute_approx_kl + return compute_approx_kl( + log_probs=log_probs, + log_probs_base=log_probs_base, + action_mask=action_mask, + kl_penalty=kl_penalty + ) + + def _agg_loss(self, loss_mat: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor: + """使用已有的agg_loss函数聚合损失""" + from roll.utils.functionals import agg_loss + return agg_loss( + loss_mat=loss_mat, + loss_mask=loss_mask, + loss_agg_mode=self.pipeline_config.loss_agg_mode + ) + + def _masked_sum(self, tensor: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: + """使用已有的masked_sum函数在掩码区域求和""" + from roll.utils.functionals import masked_sum + return masked_sum(tensor, mask, dim=dim) + + def _masked_mean(self, tensor: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor: + """使用已有的masked_mean函数在掩码区域计算均值""" + from roll.utils.functionals import masked_mean + return masked_mean(tensor, mask, dim=dim) \ No newline at end of file