Fix GRPO tool mask alignment after tool-call retokenization #5145

Open
MichalMraz wants to merge 1 commit into huggingface:main from MichalMraz:mmraz-fix

Conversation

@MichalMraz

What does this PR do?

Fixes #5144

This PR fixes a shape-mismatch bug in GRPOTrainer tool-call flow.

Root cause:

  • During _tool_call_loop, tool-round retokenization can make the completion part shorter than the previous completion.
  • tool_mask/logprobs were only extended, not truncated, so they could become longer than completion_ids.
  • This later crashes in _compute_loss when multiplying completion_mask * tool_mask.
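The mismatch described above can be sketched with plain Python lists standing in for the trainer's tensors (all lengths and names here are illustrative, taken from the example discussed later in this thread):

```python
# Minimal sketch of the length mismatch: after a tool round, retokenization
# shrank the completion, but tool_mask was only ever extended.
completion_ids = [0] * 41   # completion after tool-round retokenization (shorter)
tool_mask = [1] * 47        # stale mask from the previous 47-token completion

# _compute_loss multiplies completion_mask * tool_mask elementwise, which
# requires equal lengths; with real tensors this raises a shape-mismatch error.
mismatch = len(tool_mask) - len(completion_ids)
print(mismatch)  # 6 extra mask positions with no corresponding tokens
```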

Fix:

  • In trl/trainer/grpo_trainer.py, align tool_mask and logprobs to computed completion lengths by truncating or padding as needed before appending post-tool tokens.
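The alignment step can be sketched roughly as follows; `align_to_length` is a hypothetical helper for illustration, not the actual code in `trl/trainer/grpo_trainer.py`:

```python
def align_to_length(seq, target_len, pad_value):
    """Truncate or right-pad `seq` so that len(seq) == target_len."""
    if len(seq) > target_len:
        return seq[:target_len]  # retokenization shortened the completion
    return seq + [pad_value] * (target_len - len(seq))  # completion grew

# Before appending post-tool tokens, force the bookkeeping sequences
# to match the retokenized completion length (illustrative values).
completion_ids = [0] * 41   # new, shorter completion
tool_mask = [1] * 47        # stale, from the previous 47-token completion
logprobs = [0.0] * 47

tool_mask = align_to_length(tool_mask, len(completion_ids), pad_value=0)
logprobs = align_to_length(logprobs, len(completion_ids), pad_value=0.0)
assert len(tool_mask) == len(logprobs) == len(completion_ids)
```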

Tests:

  • Added regression test:
    tests/test_grpo_trainer.py::test_training_with_tools_keeps_masks_aligned_when_retokenization_shortens_completion
  • This test fails on old code with the tensor-size mismatch and passes with this patch.


@qgallouedec
Member

During _tool_call_loop, tool-round retokenization can make the completion part shorter than the previous completion.

shouldn't happen, what model do you use?

@MichalMraz
Author

I originally hit it with Qwen3-32B on transformers==5.0.0.
The test uses only a tiny Qwen model.

@qgallouedec
Member

qgallouedec commented Feb 21, 2026

ok, can you share the generated completion ids / tool calls when it occurs?

@MichalMraz
Author

For example, this kind of model response causes it:

assistant_text = (
        "<tool_call>\n"
        "{\n"
        '                                  "name"        :                 "multiply_tool",\n'
        '                                  "arguments"   :                 {\n'
        '                                                    "a"   :  3,\n'
        '                                                    "b"   :  4\n'
        "                                                }\n"
        "}\n"
        f"</tool_call>{tokenizer.eos_token}"
    )

The completion_ids are then
[151657, 198, 515, 6656, 330, 606, 1, 286, 549, 338, 330, 64648, 22785, 756, 6656, 330, 16370, 1, 256, 549, 338, 341, 6374, 330, 64, 1, 256, 549, 220, 220, 18, 345, 6374, 330, 65, 1, 256, 549, 220, 220, 19, 198, 4569, 456, 532, 151658, 151645]
and parsed tool_calls
[{'type': 'function', 'function': {'name': 'multiply_tool', 'arguments': {'a': 3, 'b': 4}}}]

Here

  • original completion length = 47
  • retokenized completion_tool length = 41

So the old logic computes a negative delta (41 - 47 = -6).

Yes, it is a somewhat degenerate case, but it sometimes happens during unstable training and causes a crash.
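To spell out the arithmetic: assuming append-only bookkeeping (as described in the PR body), a negative delta silently leaves the mask too long, because appending a negative number of elements is a no-op in Python. A sketch with illustrative names:

```python
orig_len = 47    # original completion length
retok_len = 41   # retokenized completion_tool length
delta = retok_len - orig_len   # -6

# Append-only logic adds nothing and trims nothing for a negative delta:
tool_mask = [1] * orig_len
tool_mask += [1] * delta       # [1] * -6 == [], so the mask keeps length 47
print(len(tool_mask), retok_len)
```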


Successfully merging this pull request may close these issues.

GRPOTrainer tool_mask can become longer than completion_ids after tool-call retokenization

2 participants