Fix GRPO tool mask alignment after tool-call retokenization #5145
MichalMraz wants to merge 1 commit into huggingface:main
Conversation
That shouldn't happen. What model do you use?
I originally hit it with Qwen3-32B on transformers==5.0.0. |
ok, can you share the generated completion ids / tool calls when it occurs?
For example, this kind of model response causes it: the `completion_ids` after retokenization end up shorter than the original completion.
So the old logic computes a negative delta (41 - 47 = -6). Yes, it is a somewhat degenerate case, but it sometimes happens during unstable training, causing a crash.
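A minimal sketch of the failure mode described above, using the hypothetical lengths from this example (47 tokens before retokenization, 41 after); this is illustrative bookkeeping, not the actual TRL code:

```python
# Retokenization shrinks the completion from 47 tokens to 41.
prev_len, new_len = 47, 41
tool_mask = [1] * prev_len        # mask built for the previous completion

delta = new_len - prev_len        # 41 - 47 = -6: negative delta
if delta > 0:                     # extend-only logic: negative delta is ignored
    tool_mask.extend([0] * delta)

# The mask is now longer than the new completion, so a later elementwise
# multiply against a length-41 completion mask fails with a shape mismatch.
print(len(tool_mask), new_len)    # 47 41
```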
What does this PR do?
Fixes #5144
This PR fixes a shape-mismatch bug in the `GRPOTrainer` tool-call flow.

Root cause: In `_tool_call_loop`, tool-round retokenization can make the completion part shorter than the previous completion. `tool_mask`/`logprobs` were only extended, never truncated, so they could become longer than `completion_ids`. This crashes `_compute_loss` when multiplying `completion_mask * tool_mask`.

Fix: In `trl/trainer/grpo_trainer.py`, align `tool_mask` and `logprobs` to the computed completion lengths by truncating or padding as needed before appending post-tool tokens.

Tests: `tests/test_grpo_trainer.py::test_training_with_tools_keeps_masks_aligned_when_retokenization_shortens_completion`
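The truncate-or-pad alignment described in the fix can be sketched as follows; `align_to_length` is a hypothetical helper for illustration, not the actual implementation in `trl/trainer/grpo_trainer.py`:

```python
def align_to_length(seq, target_len, pad_value):
    """Truncate or pad seq so that len(seq) == target_len.
    Hypothetical sketch of the alignment step, not the real TRL code."""
    if len(seq) >= target_len:
        return seq[:target_len]
    return seq + [pad_value] * (target_len - len(seq))

# Shortened-completion case from the issue: 47 tokens -> 41 tokens.
tool_mask = align_to_length([1] * 47, 41, pad_value=0)
logprobs = align_to_length([0.0] * 47, 41, pad_value=0.0)
print(len(tool_mask), len(logprobs))  # 41 41, matching completion_ids
```

Padding with 0 for the mask excludes any appended positions from the loss; truncating drops mask/logprob entries for tokens that no longer exist after retokenization.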