From f0dd05f2051ba0b5ef2e70ab0eac0865e2939535 Mon Sep 17 00:00:00 2001 From: MichalMraz Date: Sat, 21 Feb 2026 21:26:20 +0000 Subject: [PATCH] Fix GRPO tool mask alignment after tool-call retokenization --- tests/test_grpo_trainer.py | 104 ++++++++++++++++++++++++++++++++++++ trl/trainer/grpo_trainer.py | 30 ++++++++--- 2 files changed, 128 insertions(+), 6 deletions(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index b32506c2ca0..09c5d22ffa5 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -2344,6 +2344,110 @@ def fake_generate(input_ids, **kwargs): new_param = trainer.model.get_parameter(n) assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + @pytest.mark.xfail( + condition=Version(transformers.__version__) < Version("5.0.0"), + reason="Tool parsing is not supported in transformers versions below 5.0.0", + strict=True, + ) + @require_jmespath + def test_training_with_tools_keeps_masks_aligned_when_retokenization_shortens_completion(self): + dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train[:3]") + + def reward_func(completions, **kwargs): + return [0.0] * len(completions) + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=128, + max_steps=1, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen3MoeForCausalLM", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + tools=[multiply_tool], + ) + + turn = {"value": 0} + + def fake_generate_single_turn(prompts): + del prompts + turn["value"] += 1 + if turn["value"] == 1: + # Initial batch: sample 0 is interpreted as a tool call, others are plain completions. + prompt_ids = [[100, 101], [110, 111], [120, 121]] + completion_ids = [ + [900, 901, 902, 903], + [700, 701, 702], + [710, 711, 712], + ] + return prompt_ids, completion_ids, None, {} + if turn["value"] == 2: + # Tool round for the single tool-calling sample: + # prompt+completion+tool is intentionally short to trigger the edge case where + # completion_tool_length can be shorter than the previous completion length. + prompt_ids = [[100, 101, 201, 202]] + completion_ids = [[600, 601]] + return prompt_ids, completion_ids, None, {} + raise RuntimeError(f"Unexpected fake generation turn: {turn['value']}") + + def fake_parse_response(processing_class, ids): + del processing_class + if ids and ids[0] == 900: + return { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "type": "function", + "function": { + "name": "multiply_tool", + "arguments": {"a": 3, "b": 4}, + }, + } + ], + } + return {"role": "assistant", "content": "ok"} + + def fake_get_per_token_logps_and_entropies( + model, + input_ids, + attention_mask, + logits_to_keep, + *args, + compute_entropy=False, + **kwargs, + ): + del model, attention_mask, args, kwargs + shape = (input_ids.size(0), logits_to_keep) + logps = torch.zeros(shape, dtype=torch.float32, device=input_ids.device, requires_grad=True) + entropies = torch.zeros(shape, dtype=torch.float32, device=input_ids.device) if compute_entropy else None + return logps, entropies + + observed_lengths = {} + original_tool_call_loop = trainer._tool_call_loop + + def wrapped_tool_call_loop(*args, **kwargs): + out = original_tool_call_loop(*args, **kwargs) + tool_mask, _completions, completion_ids, _logprobs, *_ = out + observed_lengths["completion"] = [len(ids) for ids in completion_ids] + observed_lengths["tool"] = [len(mask) for mask in tool_mask] + return out + + with ( + patch.object(trainer, "_generate_single_turn", side_effect=fake_generate_single_turn), + patch("trl.trainer.grpo_trainer.parse_response", side_effect=fake_parse_response), + patch.object(trainer, "_get_per_token_logps_and_entropies", side_effect=fake_get_per_token_logps_and_entropies), + patch.object(trainer, "_tool_call_loop", side_effect=wrapped_tool_call_loop), + ): + trainer.train() + + assert observed_lengths["completion"] == observed_lengths["tool"] + def test_mismatched_reward_processing_classes_length(self): """Test that mismatched length between reward_funcs and reward_processing_classes raises error.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 87e3dad18cd..b580d3494ee 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1414,9 +1414,16 @@ async def _run_async_tools(async_coros): prompt_length = len(prompt_ids[idx_with_tool]) ct = pct_ids[idx][prompt_length : prompt_length + self.max_completion_length] completion_ids[idx_with_tool] = ct - tool_mask[idx_with_tool] += [1] * (len(ct) - len(tool_mask[idx_with_tool])) + # Keep tool_mask aligned with completion_ids even if ct becomes shorter. + if len(tool_mask[idx_with_tool]) > len(ct): + tool_mask[idx_with_tool] = tool_mask[idx_with_tool][: len(ct)] + else: + tool_mask[idx_with_tool] += [1] * (len(ct) - len(tool_mask[idx_with_tool])) if logprobs is not None: - logprobs[idx_with_tool] += [0.0] * (len(ct) - len(logprobs[idx_with_tool])) + if len(logprobs[idx_with_tool]) > len(ct): + logprobs[idx_with_tool] = logprobs[idx_with_tool][: len(ct)] + else: + logprobs[idx_with_tool] += [0.0] * (len(ct) - len(logprobs[idx_with_tool])) # Keep only non-overlong items for further processing idxs_with_tool = [idx for idx, o in zip(idxs_with_tool, overlong, strict=True) if not o] prompt_completion_tools = [pct for pct, o in zip(prompt_completion_tools, overlong, strict=True) if not o] @@ -1459,12 +1466,23 @@ async def _run_async_tools(async_coros): idx_with_tool = idxs_with_tool[idx] prompt_completion_tool_length = len(prompt_completion_tool_ids[idx]) prompt_length = len(prompt_ids[idx_with_tool]) - completion_length = len(completion_ids[idx_with_tool]) post_tool_length = len(post_tool_ids[idx]) - tool_length = prompt_completion_tool_length - prompt_length - completion_length - tool_mask[idx_with_tool] += [0] * tool_length + [1] * post_tool_length + + # Build mask for completion_tool_ids (tool tokens should be 0), then append post-tool model tokens (1). + completion_tool_length = prompt_completion_tool_length - prompt_length + if len(tool_mask[idx_with_tool]) > completion_tool_length: + tool_mask[idx_with_tool] = tool_mask[idx_with_tool][:completion_tool_length] + else: + tool_mask[idx_with_tool] += [0] * (completion_tool_length - len(tool_mask[idx_with_tool])) + tool_mask[idx_with_tool] += [1] * post_tool_length + if logprobs is not None: - logprobs[idx_with_tool] += [0.0] * tool_length + post_tool_logprobs[idx] + # Tool-result tokens are external and have no model logprob, so they get 0.0. + if len(logprobs[idx_with_tool]) > completion_tool_length: + logprobs[idx_with_tool] = logprobs[idx_with_tool][:completion_tool_length] + else: + logprobs[idx_with_tool] += [0.0] * (completion_tool_length - len(logprobs[idx_with_tool])) + logprobs[idx_with_tool] += post_tool_logprobs[idx] # Update completion_ids with the new completions (after tool execution) for idx in range(len(idxs_with_tool)):